diff --git a/Jenkinsfile b/Jenkinsfile index 7680bf9..63bd83f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,14 +11,27 @@ pipeline { choice( name: 'ATTACK_MODE', choices: [ - 'ALL_SEQUENTIAL', - 'NONE', - 'POISONING', - 'STEALTHY', - 'SYBIL_SIMULATION', - 'GRADIENT_INVERSION' - ], - description: 'Attack scenario to test', + 'ALL_SEQUENTIAL', + 'NONE', + 'POISONING', + 'MODEL_POISONING', + 'DATA_POISONING', + 'BACKDOOR', + 'LABEL_FLIP', + 'STEALTHY', + 'SYBIL_SIMULATION', + 'GRADIENT_INVERSION', + 'MEMBERSHIP_INFERENCE', + 'PROPERTY_INFERENCE', + 'MODEL_REPLACEMENT', + 'MALICIOUS_AGGREGATION', + 'ADVERSARIAL_EXAMPLES', + 'MODEL_DRIFT', + 'FREE_RIDING', + 'SIGN_FLIP', + 'GAUSSIAN_NOISE' + ], + description: 'Attack scenario to test' ) string( @@ -132,7 +145,7 @@ pipeline { exit 0 fi echo "Waiting for server... ($i/12)" - sleep 5 + sleep 15 done echo "Server failed to start" exit 1 @@ -151,7 +164,25 @@ pipeline { if (params.ATTACK_MODE == 'ALL_SEQUENTIAL') { // Iterate through all actual attack types - def attacks = ['POISONING', 'STEALTHY', 'SYBIL_SIMULATION', 'GRADIENT_INVERSION'] + def attacks = [ + 'POISONING', + 'MODEL_POISONING', + 'DATA_POISONING', + 'BACKDOOR', + 'LABEL_FLIP', + 'STEALTHY', + 'SYBIL_SIMULATION', + 'GRADIENT_INVERSION', + 'MEMBERSHIP_INFERENCE', + 'PROPERTY_INFERENCE', + 'MODEL_REPLACEMENT', + 'MALICIOUS_AGGREGATION', + 'ADVERSARIAL_EXAMPLES', + 'MODEL_DRIFT', + 'FREE_RIDING', + 'SIGN_FLIP', + 'GAUSSIAN_NOISE' + ]; for (attack in attacks) { runAttackScenario(attack) } @@ -265,12 +296,15 @@ def runAttackScenario(String attackMode) { # 1. Update the attack mode in the compose file sed -i.bak "s/ATTACK_MODE=.*/ATTACK_MODE=${attackMode}/" infra/docker/docker-compose-app.yml - # 2. Recreate ONLY the malicious client. + # 2. Force clean malicious client for a new attack + docker rm -f fl_malicious_client || true + + # 3. Recreate ONLY the malicious client. docker compose -f infra/docker/docker-compose-app.yml up -d --no-deps --force-recreate malicious_client - sleep 5 + sleep 20 - # 4. FIX: Use ${params.FL_ROUNDS} so Jenkins actually passes the value + # 4. Trigger training docker exec fl_server python3 utils/control.py --mode train --rounds ${params.FL_ROUNDS} --wait ${params.FL_WAIT} echo " Attack scenario ${attackMode} completed" diff --git a/fl-project/infra/monitoring/grafana/provisioning/alerting/fl-alerting.yml b/fl-project/infra/monitoring/grafana/provisioning/alerting/fl-alerting.yml index 88ca8ae..6c9e86f 100644 --- a/fl-project/infra/monitoring/grafana/provisioning/alerting/fl-alerting.yml +++ b/fl-project/infra/monitoring/grafana/provisioning/alerting/fl-alerting.yml @@ -1,75 +1,75 @@ -apiVersion: 1 +# apiVersion: 1 -groups: - - orgId: 1 - name: 'FL-Security-Alerts' - folder: 'Federated Learning' - interval: 10s - rules: - - uid: 'alert-attack-detected' - title: 'Active Attack Detected' - condition: 'B' - data: - - refId: 'A' - relativeTimeRange: - from: 300 - to: 0 - datasourceUid: 'P8E80F9AEF21F6940' - model: - expr: 'count_over_time({job="fl_server"} | json | event_type="ATTACK_DETECTED" [1m])' - refId: 'A' - - refId: 'B' - datasourceUid: '__expr__' - model: - conditions: [ { evaluator: { params: [ 0 ], type: 'gt' }, operator: { type: 'and' }, query: { params: [ 'A' ] }, reducer: { params: [], type: 'sum' }, type: 'query' } ] - datasource: { type: '__expr__', uid: '__expr__' } - expression: 'A' - type: 'threshold' - noDataState: 'OK' - for: 0s - annotations: - summary: 'Attack detected on FL Server' - description: 'The security module detected a malicious update attempting to poison the model.' - labels: - severity: 'critical' - category: 'security' +# groups: +# - orgId: 1 +# name: 'FL-Security-Alerts' +# folder: 'Federated Learning' +# interval: 10s +# rules: +# - uid: 'alert-attack-detected' +# title: 'Active Attack Detected' +# condition: 'B' +# data: +# - refId: 'A' +# relativeTimeRange: +# from: 300 +# to: 0 +# datasourceUid: 'P8E80F9AEF21F6940' +# model: +# expr: 'count_over_time({job="fl_server"} | json | event_type="ATTACK_DETECTED" [1m])' +# refId: 'A' +# - refId: 'B' +# datasourceUid: '__expr__' +# model: +# conditions: [ { evaluator: { params: [ 0 ], type: 'gt' }, operator: { type: 'and' }, query: { params: [ 'A' ] }, reducer: { params: [], type: 'sum' }, type: 'query' } ] +# datasource: { type: '__expr__', uid: '__expr__' } +# expression: 'A' +# type: 'threshold' +# noDataState: 'OK' +# for: 0s +# annotations: +# summary: 'Attack detected on FL Server' +# description: 'The security module detected a malicious update attempting to poison the model.' +# labels: +# severity: 'critical' +# category: 'security' - - uid: 'alert-sast-critical' - title: 'ā˜£ļø Critical Vulnerability in Code' - condition: 'B' - data: - - refId: 'A' - datasourceUid: 'PBFA97CFB590B2093' - model: - expr: 'sast_bandit_issues_high' - refId: 'A' - - refId: 'B' - datasourceUid: '__expr__' - model: - conditions: [ { evaluator: { params: [ 0 ], type: 'gt' }, operator: { type: 'and' }, query: { params: [ 'A' ] }, reducer: { params: [], type: 'last' }, type: 'query' } ] - datasource: { type: '__expr__', uid: '__expr__' } - expression: 'A' - type: 'threshold' - noDataState: 'OK' - for: 1m - annotations: - summary: 'Critical SAST Issue Found' - description: 'Bandit scan found high-severity vulnerabilities in the production code.' - labels: - severity: 'warning' - category: 'code-quality' +# - uid: 'alert-sast-critical' +# title: 'ā˜£ļø Critical Vulnerability in Code' +# condition: 'B' +# data: +# - refId: 'A' +# datasourceUid: 'PBFA97CFB590B2093' +# model: +# expr: 'sast_bandit_issues_high' +# refId: 'A' +# - refId: 'B' +# datasourceUid: '__expr__' +# model: +# conditions: [ { evaluator: { params: [ 0 ], type: 'gt' }, operator: { type: 'and' }, query: { params: [ 'A' ] }, reducer: { params: [], type: 'last' }, type: 'query' } ] +# datasource: { type: '__expr__', uid: '__expr__' } +# expression: 'A' +# type: 'threshold' +# noDataState: 'OK' +# for: 1m +# annotations: +# summary: 'Critical SAST Issue Found' +# description: 'Bandit scan found high-severity vulnerabilities in the production code.' +# labels: +# severity: 'warning' +# category: 'code-quality' -contactPoints: - - orgId: 1 - name: 'Critical-Channel' - receivers: - - uid: 'email-receiver' - type: email - settings: - addresses: 'admin@example.com' - singleEmail: true +# contactPoints: +# - orgId: 1 +# name: 'Critical-Channel' +# receivers: +# - uid: 'email-receiver' +# type: email +# settings: +# addresses: 'admin@example.com' +# singleEmail: true -policies: - - orgId: 1 - receiver: 'Critical-Channel' - group_by: ['alertname'] \ No newline at end of file +# policies: +# - orgId: 1 +# receiver: 'Critical-Channel' +# group_by: ['alertname'] \ No newline at end of file diff --git a/fl-project/infra/monitoring/grafana/provisioning/dashboards/2-security-center.json b/fl-project/infra/monitoring/grafana/provisioning/dashboards/2-security-center.json index a081cef..887011b 100644 --- a/fl-project/infra/monitoring/grafana/provisioning/dashboards/2-security-center.json +++ b/fl-project/infra/monitoring/grafana/provisioning/dashboards/2-security-center.json @@ -1,38 +1,245 @@ { "title": "2. FL - Security Center", - "uid": "fl-security-center", - "tags": ["security", "attacks"], + "uid": "fl-security-center-v2", + "tags": ["security", "attacks", "federated-learning"], "timezone": "browser", "refresh": "5s", + "editable": true, "panels": [ { - "title": "Attack Detection Feed", - "type": "logs", - "gridPos": { "h": 14, "w": 14, "x": 0, "y": 0 }, - "datasource": "Loki", + "id": 1, + "title": "šŸ”“ Total Attacks Detected", + "type": "stat", + "gridPos": { "h": 6, "w": 4, "x": 0, "y": 0 }, + "datasource": "Prometheus", "targets": [ - { "expr": "{job=\"fl_server\"} | json | event_type=~\"ATTACK_DETECTED|ATTACK_REJECTED\"" } - ] + { + "expr": "sum(fl_attacks_detected_total)", + "legendFormat": "Total Attacks", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "values": false, + "calcs": ["lastNotNull"] + }, + "text": {}, + "colorMode": "background", + "graphMode": "none" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "green", "value": 0 }, + { "color": "yellow", "value": 1 }, + { "color": "red", "value": 5 } + ] + }, + "unit": "short" + } + } + }, + { + "id": 2, + "title": "āš ļø High Confidence Attacks (>90%)", + "type": "stat", + "gridPos": { "h": 6, "w": 4, "x": 4, "y": 0 }, + "datasource": "Prometheus", + "targets": [ + { + "expr": "sum(fl_attacks_high_confidence_total)", + "legendFormat": "High Confidence", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "values": false, + "calcs": ["lastNotNull"] + }, + "colorMode": "background", + "graphMode": "none" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "green", "value": 0 }, + { "color": "orange", "value": 1 }, + { "color": "red", "value": 3 } + ] + } + } + } + }, + { + "id": 3, + "title": "šŸ“Š Current Training Round", + "type": "stat", + "gridPos": { "h": 6, "w": 4, "x": 8, "y": 0 }, + "datasource": "Prometheus", + "targets": [ + { + "expr": "fl_training_round", + "legendFormat": "Round", + "refId": "A" + } + ], + "options": { + "colorMode": "value", + "graphMode": "none" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + } + } + } }, { - "title": "Attacks by Type", + "id": 4, + "title": "Attacks by Type (All Time)", "type": "piechart", - "gridPos": { "h": 14, "w": 10, "x": 14, "y": 0 }, + "gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 }, + "datasource": "Prometheus", + "targets": [ + { + "expr": "sum by (attack_type) (fl_attacks_detected_total)", + "legendFormat": "{{attack_type}}", + "refId": "A" + } + ], + "options": { + "legend": { + "displayMode": "table", + "placement": "right", + "values": ["value"] + }, + "pieType": "donut" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + } + } + } + }, + { + "id": 5, + "title": "🚨 Attack Detection Feed (Last 100 Events)", + "type": "logs", + "gridPos": { "h": 12, "w": 12, "x": 0, "y": 6 }, "datasource": "Loki", "targets": [ - { "expr": "{job=\"fl_server\"} | json | event_type=\"ATTACK_DETECTED\" | attack_type" } - ] + { + "expr": "{container=\"fl_server\"} |= \"event_type\" | json | event_type=~\"ATTACK_DETECTED|ATTACK_REJECTED|SECURITY_ALERT\"", + "refId": "A" + } + ], + "options": { + "showTime": true, + "showLabels": false, + "showCommonLabels": false, + "wrapLogMessage": true, + "dedupStrategy": "none", + "enableLogDetails": true + } + }, + { + "id": 6, + "title": "Attacks Over Time (Rate)", + "type": "timeseries", + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 10 }, + "datasource": "Prometheus", + "targets": [ + { + "expr": "sum(rate(fl_attacks_detected_total[1m])) * 60", + "legendFormat": "Attacks per minute", + "refId": "A" + } + ], + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + } + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "drawStyle": "line", + "lineInterpolation": "smooth", + "fillOpacity": 10 + }, + "unit": "attacks/min" + } + } }, { - "title": "High Confidence Threats (>90%)", + "id": 7, + "title": "Client Security Status", "type": "table", - "gridPos": { "h": 8, "w": 24, "x": 0, "y": 14 }, - "datasource": "Loki", + "gridPos": { "h": 10, "w": 24, "x": 0, "y": 18 }, + "datasource": "Prometheus", "targets": [ - { "expr": "{job=\"fl_server\"} | json | event_type=\"ATTACK_DETECTED\" | confidence > 0.9" } + { + "expr": "fl_attacks_detected_total", + "legendFormat": "", + "refId": "A", + "format": "table", + "instant": true + }, + { + "expr": "fl_updates_accepted_total", + "refId": "B", + "format": "table", + "instant": true + }, + { + "expr": "fl_updates_rejected_total", + "refId": "C", + "format": "table", + "instant": true + } ], "transformations": [ - { "id": "organize", "options": { "renameByName": { "client_id": "Attacker", "attack_type": "Attack", "confidence": "Conf.", "timestamp": "Time" } } } + { + "id": "merge", + "options": {} + }, + { + "id": "organize", + "options": { + "excludeByName": { + "Time": true, + "job": true, + "instance": true, + "__name__": true + }, + "renameByName": { + "client_id": "Client", + "attack_type": "Attack Type", + "Value #A": "Attacks", + "Value #B": "Accepted", + "Value #C": "Rejected" + } + } + } ] } ] diff --git a/fl-project/infra/monitoring/prometheus.yml b/fl-project/infra/monitoring/prometheus.yml index 52fe25f..f390dbc 100644 --- a/fl-project/infra/monitoring/prometheus.yml +++ b/fl-project/infra/monitoring/prometheus.yml @@ -7,6 +7,7 @@ scrape_configs: static_configs: - targets: ['server:5000'] metrics_path: '/metrics' + scrape_interval: 5s - job_name: 'pushgateway' honor_labels: true diff --git a/fl-project/src/client.py b/fl-project/src/client.py index f3794e4..f39acae 100644 --- a/fl-project/src/client.py +++ b/fl-project/src/client.py @@ -139,21 +139,29 @@ def fetch_model(self): def train_locally(self, epochs=2): self.logger.info(f"Starting local training ({epochs} epochs)") - X, y = self.training_data - - history = self.model.fit( - X, y, - epochs=epochs, - batch_size=8, - verbose=0, - validation_split=0.2 - ) - - loss = float(history.history['loss'][-1]) - accuracy = float(history.history['accuracy'][-1]) - - self.logger.info(f"Training completed - loss={loss:.4f}, accuracy={accuracy:.4f}") - return loss, accuracy + try: + X, y = self.training_data + + self.logger.debug(f"Training data shape: X={X.shape}, y={y.shape}") + + history = self.model.fit( + X, y, + epochs=epochs, + batch_size=8, + verbose=0, + validation_split=0.2 + ) + + loss = float(history.history['loss'][-1]) + accuracy = float(history.history['accuracy'][-1]) + + self.logger.info(f"Training completed - loss={loss:.4f}, accuracy={accuracy:.4f}") + return loss, accuracy + except Exception as e: + self.logger.error(f"Training failed with error: {e}") + self.logger.exception(e) + # Return dummy values so client doesn't crash + return 999.9, 0.0 def submit_update(self, loss, accuracy): try: @@ -233,13 +241,34 @@ def wait_for_server_signal(self, timeout=60): timeout=timeout ) if response.status_code == 200: - return True + data = response.json() + status = data.get('status') + + # Only return True if server explicitly says to train + if status == 'go_train': + # Update our round number to match server + if 'round' in data: + server_round = data['round'] + if server_round != self.current_round: + self.logger.info(f"Server advanced to round {server_round} (was {self.current_round})") + self.current_round = server_round + return True + elif status == 'already_served_this_round': + # Client already trained this round, wait for next + self.logger.debug("Already served for current round, waiting for next round...") + return False + else: + self.logger.debug(f"Server status: {status}") + return False + else: + self.logger.warning(f"Unexpected status code from server: {response.status_code}") + return False except requests.Timeout: + self.logger.debug("Timeout waiting for signal") return False except Exception as e: self.logger.error(f"Error waiting for signal: {e}") return False - return False def main(): client_id = os.getenv('CLIENT_ID', 'client_1') @@ -263,13 +292,23 @@ def main(): # Main loop: wait for signal, train, repeat logger.info("Entering main loop - waiting for training signals") + consecutive_waits = 0 while True: try: # Wait for server to signal a training round + logger.debug(f"Polling server for training signal (attempt {consecutive_waits + 1})...") if client.wait_for_server_signal(timeout=300): logger.info(f"šŸ”” Received training signal from server") - client.run_training_cycle() + consecutive_waits = 0 # Reset counter + success = client.run_training_cycle() + if success: + logger.info("āœ“ Training cycle completed successfully") + else: + logger.warning("āš ļø Training cycle failed") else: + consecutive_waits += 1 + if consecutive_waits % 12 == 0: # Log every minute (12 * 5s) + logger.info(f"Still waiting for signal... ({consecutive_waits * 5}s elapsed)") logger.debug("No training signal received, waiting...") time.sleep(5) except KeyboardInterrupt: @@ -277,6 +316,7 @@ def main(): break except Exception as e: logger.error(f"Unexpected error: {e}") + logger.exception(e) # Print full stack trace time.sleep(5) if __name__ == '__main__': diff --git a/fl-project/src/malicious_client.py b/fl-project/src/malicious_client.py index 0ee33e6..b733a3c 100644 --- a/fl-project/src/malicious_client.py +++ b/fl-project/src/malicious_client.py @@ -20,30 +20,59 @@ ) class MaliciousClient: - def __init__(self, client_id, server_url, data_file, attack_mode='NONE', attack_rounds=None): + """ + Malicious client implementing AGGRESSIVE attack types from FL taxonomy. + + INTEGRITY ATTACKS: + - DATA_POISONING: Corrupt training data heavily + - MODEL_POISONING: Manipulate model updates aggressively (Byzantine) + - BACKDOOR: Inject strong backdoor trigger patterns + - LABEL_FLIP: Flip training labels completely + + PRIVACY ATTACKS: + - GRADIENT_INVERSION: Reconstruct training data from gradients + - MEMBERSHIP_INFERENCE: Infer if data was in training set + - PROPERTY_INFERENCE: Infer dataset properties + + AGGREGATOR/SERVER ATTACKS: + - MODEL_REPLACEMENT: Replace entire model with adversarial version + - MALICIOUS_AGGREGATION: Craft extreme updates to exploit aggregation + + ADVERSARIAL/ROBUSTNESS ATTACKS: + - ADVERSARIAL_EXAMPLES: Add strong adversarial perturbations + - MODEL_DRIFT: Cause significant model drift + - FREE_RIDING: Submit minimal/fake updates + + STEALTHY: Combination attack that tries to be more subtle + """ + + def __init__(self, client_id, server_url, data_file, attack_mode='NONE'): self.client_id = client_id self.server_url = server_url self.attack_mode = attack_mode - self.attack_rounds = attack_rounds or [2, 3, 4, 5] - self.current_round = 0 self.logger = logging.getLogger(f"CLIENT-{client_id}") - # ADDED: Security - Registration token and cryptographic keys + # Security self.registration_token = os.getenv('REGISTRATION_TOKEN') self.server_public_key = None self.private_key = ed25519.Ed25519PrivateKey.generate() self.public_key = self.private_key.public_key() - # Create model structure (with random initial weights) + # Model self.model = self._create_model() - # Load training data + # Data self.training_data = self._load_data(data_file) + self.original_data = self._load_data(data_file) # Keep clean copy + + # Attack tracking + self.intercepted_gradients = [] + self.attack_history = [] + self.free_riding_mode = False - self.logger.info(f"Initialized client (attack_mode={attack_mode})") + self.logger.info(f"āš ļø Initialized MALICIOUS client (attack_mode={attack_mode})") - # Register with server and receive initial model weights self._register_with_server() def _create_model(self): @@ -67,7 +96,6 @@ def _load_data(self, data_file): def _register_with_server(self): try: - # ADDED: Serialize public key to PEM format pem = self.public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo @@ -78,36 +106,29 @@ def _register_with_server(self): json={ 'client_id': self.client_id, 'num_samples': len(self.training_data[0]), - 'public_key': pem, # ADDED - 'token': self.registration_token # ADDED + 'public_key': pem, + 'token': self.registration_token } ) data = response.json() - # ADDED: Store server's public key for signature verification if 'server_public_key' in data: self.server_public_key = serialization.load_pem_public_key( data['server_public_key'].encode() ) - self.logger.info("Server public key received and stored") + self.logger.info("Server public key received") - # CRITICAL: Set initial weights from server if 'initial_weights' in data: weights = [np.array(w) for w in data['initial_weights']] self.model.set_weights(weights) - self.current_round = data['round'] - self.logger.info(f"Registered with server and received initial model (round {self.current_round})") - self.logger.info(f"Model synchronized with server's global model") - else: - self.logger.warning("No initial weights received from server!") - self.logger.warning("Model may not be synchronized!") + self.logger.info(f"āœ“ Registered with server") except Exception as e: self.logger.error(f"Registration failed: {e}") def fetch_model(self): - """Fetch the latest global model from server before training""" + """Fetch global model""" try: response = requests.post( f'{self.server_url}/get_model', @@ -115,36 +136,62 @@ def fetch_model(self): ) data = response.json() - # ADDED: Verify server signature if available + # Verify signature if self.server_public_key and 'signature' in data: try: payload_content = data['payload'] signature = base64.b64decode(data['signature']) payload_bytes = json.dumps(payload_content, sort_keys=True).encode() self.server_public_key.verify(signature, payload_bytes) - self.logger.info("āœ“ Global model signature VALIDATED") weights = [np.array(w) for w in payload_content['weights']] - self.current_round = payload_content['round'] except Exception as e: - self.logger.critical(f"šŸ”“ SECURITY ALERT: Server signature INVALID! {e}") + self.logger.critical(f"šŸ”“ Server signature INVALID! {e}") return False else: - # Backward compatibility: no signature weights = [np.array(w) for w in data['weights']] - self.current_round = data['round'] self.model.set_weights(weights) - self.logger.info(f"āœ“ Fetched latest global model from server (round {self.current_round})") + + # Store for privacy attacks + if self.attack_mode in ['GRADIENT_INVERSION', 'MEMBERSHIP_INFERENCE', 'PROPERTY_INFERENCE', 'STEALTHY']: + self.intercepted_gradients.append({ + 'weights': weights + }) + + self.logger.info(f"āœ“ Fetched model") return True except Exception as e: self.logger.error(f"Failed to fetch model: {e}") return False def train_locally(self, epochs=2): + """Train with potential data attacks""" self.logger.info(f"Starting local training ({epochs} epochs)") + + # FREE_RIDING: Skip training entirely + if self.attack_mode == 'FREE_RIDING': + self.logger.warning(f"šŸ”“ FREE_RIDING: Skipping training, will submit fake updates") + return 0.0, 0.0 + X, y = self.training_data + # DATA_POISONING: Use heavily corrupted data + if self.attack_mode == 'DATA_POISONING' or self.attack_mode == 'POISONING': + X, y = self._attack_data_poisoning(X, y) + + # LABEL_FLIP: Completely flip labels + elif self.attack_mode == 'LABEL_FLIP': + y = self._attack_label_flip(y) + + # ADVERSARIAL_EXAMPLES: Add strong adversarial perturbations + elif self.attack_mode == 'ADVERSARIAL_EXAMPLES': + X = self._attack_adversarial_examples(X) + + # STEALTHY: Subtle data corruption + elif self.attack_mode == 'STEALTHY': + X, y = self._attack_stealthy_data(X, y) + history = self.model.fit( X, y, epochs=epochs, @@ -159,90 +206,330 @@ def train_locally(self, epochs=2): self.logger.info(f"Training completed - loss={loss:.4f}, accuracy={accuracy:.4f}") return loss, accuracy - def _should_attack(self): - return True + # ========== INTEGRITY ATTACKS (AGGRESSIVE) ========== - def _attack_poisoning(self, weights): - """Scale weights massively to dominate aggregation""" - self.logger.warning(f"EXECUTING POISONING ATTACK - Round {self.current_round}") + def _attack_data_poisoning(self, X, y): + """ + AGGRESSIVE DATA_POISONING: Heavily corrupt training data. + This should result in very high loss and low accuracy. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE DATA_POISONING ATTACK") - gamma = 100.0 - # 1. Multiply the weights - # 2. Convert back to list using .tolist() so JSON can read it - poisoned = [(np.array(w) * gamma).tolist() for w in weights] + # Add MASSIVE noise to features + noise_level = 2.0 # Was 0.3 - now 2.0 (very aggressive) + X_poisoned = X + np.random.normal(0, noise_level, X.shape) + + # Flip MOST labels (80%) + flip_ratio = 0.8 # Was 0.2 - now 0.8 + num_flip = int(len(y) * flip_ratio) + flip_indices = np.random.choice(len(y), num_flip, replace=False) + y_poisoned = y.copy() + y_poisoned[flip_indices] = np.random.randint(0, 10, num_flip) - self.logger.warning(f"Scaled weights by gamma={gamma}") + self.logger.warning(f"šŸ’„ Corrupted {num_flip} samples ({flip_ratio*100}%) with heavy noise") + return X_poisoned, y_poisoned + + def _attack_label_flip(self, labels): + """AGGRESSIVE LABEL_FLIP: Completely reverse all labels""" + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE LABEL_FLIP ATTACK") + num_classes = int(np.max(labels)) + 1 + # Completely reverse the labels + flipped = (num_classes - 1 - labels) % num_classes + self.logger.warning(f"šŸ’„ Flipped ALL {len(labels)} labels (complete reversal)") + return flipped + + def _attack_model_poisoning(self, weights): + """ + AGGRESSIVE MODEL_POISONING: Byzantine attack with MASSIVE scaling. + This should be very obvious to detect. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE MODEL_POISONING ATTACK") + gamma = 1000.0 # Was 100.0 - now 1000.0 (10x more aggressive) + poisoned = [(np.array(w) * gamma).tolist() for w in weights] + self.logger.warning(f"šŸ’„ Scaled weights by MASSIVE gamma={gamma}") return poisoned - - def _attack_stealthy(self, weights): - """Constrain-and-scale - hide attack in normal magnitude but still detectable""" - self.logger.warning(f"EXECUTING STEALTHY ATTACK - Round {self.current_round}") + + def _attack_backdoor(self, weights): + """ + AGGRESSIVE BACKDOOR: Inject STRONG trigger pattern in output layer. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE BACKDOOR ATTACK") + + backdoored = [] + for i, w in enumerate(weights): + w_arr = np.array(w) + if i == len(weights) - 1: # Output layer + # Much stronger trigger pattern + trigger_pattern = np.random.uniform(-1.0, 1.0, w_arr.shape) # Was ±0.1, now ±1.0 + w_backdoor = w_arr + trigger_pattern * 10.0 # Was 2.0, now 10.0 + else: + w_backdoor = w_arr + backdoored.append(w_backdoor.tolist()) + + self.logger.warning(f"šŸ’„ Injected STRONG backdoor trigger in output layer") + return backdoored + + # ========== AGGREGATOR/SERVER ATTACKS (AGGRESSIVE) ========== + + def _attack_model_replacement(self, weights): + """ + AGGRESSIVE MODEL_REPLACEMENT: Replace with completely adversarial model. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE MODEL_REPLACEMENT ATTACK") + + # Create adversarial model + adversarial_model = self._create_model() + X, y = self.original_data + + # Completely reverse labels for adversarial training + y_adv = (10 - 1 - y) % 10 # Complete reversal + + adversarial_model.fit(X, y_adv, epochs=3, batch_size=8, verbose=0) # More epochs + + # Get adversarial weights + adv_weights = [w.tolist() for w in adversarial_model.get_weights()] + + # Scale MASSIVELY to dominate aggregation + scale_factor = 100.0 # Was 10.0 - now 100.0 + replaced = [(np.array(w) * scale_factor).tolist() for w in adv_weights] + + self.logger.warning(f"šŸ’„ Replaced model with STRONG adversarial version (scaled by {scale_factor})") + return replaced + + def _attack_malicious_aggregation(self, weights): + """ + AGGRESSIVE MALICIOUS_AGGREGATION: Create updates with VERY low entropy. + This creates a uniform directional attack that should be easily detected. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE MALICIOUS_AGGREGATION ATTACK") + + malicious = [] + # Create a SINGLE direction for ALL weights (very low entropy) + global_direction = np.random.randn(1)[0] # Single number for all weights + + for w in weights: + w_arr = np.array(w) + + # Push ALL weights in the SAME direction (extremely low entropy) + # This creates a uniform update that will have entropy < 1.0 + w_mal = w_arr + global_direction * np.abs(w_arr) * 5.0 # Was 1.5, now 5.0 + + malicious.append(w_mal.tolist()) + + self.logger.warning(f"šŸ’„ Crafted UNIFORM directional updates (very low entropy, scale=5.0)") + return malicious + + # ========== ADVERSARIAL/ROBUSTNESS ATTACKS (AGGRESSIVE) ========== + + def _attack_adversarial_examples(self, X): + """ + AGGRESSIVE ADVERSARIAL_EXAMPLES: Add STRONG perturbations. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE ADVERSARIAL_EXAMPLES ATTACK") + + # Much larger epsilon + epsilon = 1.0 # Was 0.1, now 1.0 (10x stronger) + perturbations = np.random.uniform(-epsilon, epsilon, X.shape) + X_adv = X + perturbations + + self.logger.warning(f"šŸ’„ Added STRONG adversarial perturbations (epsilon={epsilon})") + return X_adv + + def _attack_model_drift(self, weights): + """ + AGGRESSIVE MODEL_DRIFT: Cause SIGNIFICANT drift. + """ + self.logger.warning(f"šŸ”“ EXECUTING AGGRESSIVE MODEL_DRIFT ATTACK") + + # Much higher drift rate + drift_rate = 0.5 # Was 0.05, now 0.5 (10x stronger) + drifted = [] - # More aggressive - constrained = [] for w in weights: w_arr = np.array(w) - mean = np.mean(w_arr) - std = np.std(w_arr) - # Scale by 1.5 and add controlled noise - perturbed = w_arr * 1.5 + np.random.normal(0, std * 0.1, w_arr.shape) - constrained.append(perturbed.tolist()) - - self.logger.warning(f"Applied stealthy perturbation (1.5x scale + noise)") - return constrained - - def _attack_sybil(self, weights): - """Create correlated updates simulating multiple clients""" - self.logger.warning(f"EXECUTING SYBIL ATTACK - Round {self.current_round}") - - # Create weights that deviate from global model - base = [] + # Strong consistent directional drift + drift = np.random.randn(*w_arr.shape) * np.std(w_arr) * drift_rate + w_drift = w_arr + drift + drifted.append(w_drift.tolist()) + + self.logger.warning(f"šŸ’„ Applied STRONG model drift (rate={drift_rate})") + return drifted + + def _attack_free_riding(self, weights): + """ + FREE_RIDING: Submit essentially ZERO updates. + """ + self.logger.warning(f"šŸ”“ EXECUTING FREE_RIDING ATTACK") + + # Submit weights with TINY changes + fake_updates = [] for w in weights: w_arr = np.array(w) - # Create significantly different weights to simulate multiple malicious clients - # Use 3x scaling to make it detectable - perturbed = w_arr * 3.0 + np.random.normal(0, 1.0, w_arr.shape) - base.append(perturbed.tolist()) + # Extremely tiny noise (should trigger free-riding detection) + noise = np.random.normal(0, np.std(w_arr) * 0.00001, w_arr.shape) # Was 0.01, now 0.00001 + fake_w = w_arr + noise + fake_updates.append(fake_w.tolist()) - self.logger.warning(f"Created sybil simulation (3x scale)") - return base + self.logger.warning(f"šŸ’„ Submitted MINIMAL fake updates (free-riding)") + return fake_updates + + # ========== STEALTHY ATTACK (Combination) ========== + + def _attack_stealthy_data(self, X, y): + """ + STEALTHY: Subtle combination attack. + Less aggressive but still detectable. + """ + self.logger.warning(f"šŸ”“ EXECUTING STEALTHY ATTACK (subtle data corruption)") + + # Light noise + noise_level = 0.2 + X_poisoned = X + np.random.normal(0, noise_level, X.shape) + + # Flip only 10% of labels + flip_ratio = 0.1 + num_flip = int(len(y) * flip_ratio) + flip_indices = np.random.choice(len(y), num_flip, replace=False) + y_poisoned = y.copy() + y_poisoned[flip_indices] = (y_poisoned[flip_indices] + 1) % 10 + + self.logger.warning(f"šŸ”¶ Subtle corruption: {num_flip} samples ({flip_ratio*100}%)") + return X_poisoned, y_poisoned + + # ========== PRIVACY ATTACKS ========== def _attack_gradient_inversion(self, weights): - """Amplify gradients to expose training data""" - self.logger.warning(f"EXECUTING GRADIENT INVERSION ATTACK - Round {self.current_round}") + """GRADIENT_INVERSION: Analyze gradients to reconstruct training data""" + self.logger.warning(f"šŸ”“ EXECUTING GRADIENT_INVERSION ATTACK") + + if len(self.intercepted_gradients) > 1: + prev_weights = self.intercepted_gradients[-2]['weights'] + curr_weights = weights + + gradient_diffs = [] + for prev_w, curr_w in zip(prev_weights, curr_weights): + diff = np.array(curr_w) - np.array(prev_w) + gradient_diffs.append(diff) + + self.attack_history.append({ + 'attack': 'GRADIENT_INVERSION', + 'gradient_stats': { + 'mean': float(np.mean([np.mean(g) for g in gradient_diffs])), + 'std': float(np.std([np.std(g) for g in gradient_diffs])) + } + }) + self.logger.warning(f"šŸ“Š Analyzed gradients from {len(self.intercepted_gradients)} rounds") + + return weights + + def _attack_membership_inference(self, weights): + """MEMBERSHIP_INFERENCE: Infer if data was in training""" + self.logger.warning(f"šŸ”“ EXECUTING MEMBERSHIP_INFERENCE ATTACK") - # Amplify each weight array - amplified = [] + X_test, y_test = self.training_data + sample_indices = np.random.choice(len(X_test), min(10, len(X_test)), replace=False) + + predictions = self.model.predict(X_test[sample_indices], verbose=0) + confidences = np.max(predictions, axis=1) + + self.attack_history.append({ + 'attack': 'MEMBERSHIP_INFERENCE', + 'avg_confidence': float(np.mean(confidences)), + 'num_samples': len(sample_indices) + }) + self.logger.warning(f"šŸ“Š Membership inference on {len(sample_indices)} samples") + + return weights + + def _attack_property_inference(self, weights): + """PROPERTY_INFERENCE: Infer properties of other clients' data""" + self.logger.warning(f"šŸ”“ EXECUTING PROPERTY_INFERENCE ATTACK") + + weight_stats = [] for w in weights: w_arr = np.array(w) - # Amplify by 20x and convert to list - amplified.append((w_arr * 20.0).tolist()) + weight_stats.append({ + 'mean': float(np.mean(w_arr)), + 'std': float(np.std(w_arr)), + 'min': float(np.min(w_arr)), + 'max': float(np.max(w_arr)) + }) + + self.attack_history.append({ + 'attack': 'PROPERTY_INFERENCE', + 'weight_stats': weight_stats + }) + self.logger.warning(f"šŸ“Š Inferred properties from weight distributions") - self.logger.warning(f"Amplified gradients by 20x for DLG attack") - return amplified + return weights def submit_update(self, loss, accuracy): + """Submit update with potential attacks""" weights = [w.tolist() for w in self.model.get_weights()] + is_attack = False + attack_type = 'NONE' - # Check if should attack - if self.attack_mode != 'NONE' and self._should_attack(): - self.logger.warning(f"šŸ”“ ATTACK TRIGGERED IN ROUND {self.current_round}") - - if self.attack_mode == 'POISONING': - weights = self._attack_poisoning(weights) - elif self.attack_mode == 'STEALTHY': - weights = self._attack_stealthy(weights) - elif self.attack_mode == 'SYBIL': - weights = self._attack_sybil(weights) - elif self.attack_mode == 'GRADIENT_INVERSION': - weights = self._attack_gradient_inversion(weights) - - is_poisoned = True - else: - is_poisoned = False + # Apply attacks based on mode + if self.attack_mode == 'MODEL_POISONING' or self.attack_mode == 'POISONING': + weights = self._attack_model_poisoning(weights) + is_attack = True + attack_type = 'MODEL_POISONING' + + elif self.attack_mode == 'BACKDOOR': + weights = self._attack_backdoor(weights) + is_attack = True + attack_type = 'BACKDOOR' + + elif self.attack_mode == 'MODEL_REPLACEMENT': + weights = self._attack_model_replacement(weights) + is_attack = True + attack_type = 'MODEL_REPLACEMENT' + + elif self.attack_mode == 'MALICIOUS_AGGREGATION': + weights = self._attack_malicious_aggregation(weights) + is_attack = True + attack_type = 'MALICIOUS_AGGREGATION' + + elif self.attack_mode == 'MODEL_DRIFT': + weights = self._attack_model_drift(weights) + is_attack = True + attack_type = 'MODEL_DRIFT' + + elif self.attack_mode == 'FREE_RIDING': + weights = self._attack_free_riding(weights) + is_attack = True + attack_type = 'FREE_RIDING' + + # Privacy attacks (don't modify weights but still mark as attack) + elif self.attack_mode == 'GRADIENT_INVERSION': + weights = self._attack_gradient_inversion(weights) + is_attack = True + attack_type = 'GRADIENT_INVERSION' + + elif self.attack_mode == 'MEMBERSHIP_INFERENCE': + weights = self._attack_membership_inference(weights) + is_attack = True + attack_type = 'MEMBERSHIP_INFERENCE' + + elif self.attack_mode == 'PROPERTY_INFERENCE': + weights = self._attack_property_inference(weights) + is_attack = True + attack_type = 'PROPERTY_INFERENCE' + + # Stealthy: combination attack + elif self.attack_mode == 'STEALTHY': + # Data corruption already done during training + # Also do gradient inversion + weights = self._attack_gradient_inversion(weights) + is_attack = True + attack_type = 'STEALTHY' + + # Data attacks already applied during training + if self.attack_mode in ['DATA_POISONING', 'LABEL_FLIP', 'ADVERSARIAL_EXAMPLES']: + is_attack = True + attack_type = self.attack_mode try: - # MODIFIED: Create signed payload payload_content = { 'client_id': self.client_id, 'weights': weights, @@ -250,13 +537,12 @@ def submit_update(self, loss, accuracy): 'loss': loss, 'accuracy': accuracy, 'timestamp': datetime.now().isoformat(), - 'is_poisoned': is_poisoned, - 'attack_type': self.attack_mode if is_poisoned else 'NONE' + 'is_attack': is_attack, + 'attack_type': attack_type }, - 'round': self.current_round # ADDED } - # ADDED: Sign the payload (even malicious updates are signed) + # Sign payload payload_bytes = json.dumps(payload_content, sort_keys=True).encode() signature = self.private_key.sign(payload_bytes) signature_b64 = base64.b64encode(signature).decode('utf-8') @@ -272,41 +558,28 @@ def submit_update(self, loss, accuracy): ) if response.status_code == 200: - if is_poisoned: - self.logger.warning(f"Poisoned update submitted successfully") + if is_attack: + self.logger.warning(f"šŸ’„ Attack update submitted ({attack_type})") else: - self.logger.info(f"Update submitted successfully") + self.logger.info(f"āœ“ Update submitted") return True else: - self.logger.warning(f"Update rejected: {response.text}") + self.logger.warning(f"āŒ Update rejected: {response.text}") return False except Exception as e: - self.logger.error(f"Failed to submit update: {e}") + self.logger.error(f"Failed to submit: {e}") return False def run_training_cycle(self): - """ - Complete training cycle: - 1. Fetch latest global model from server - 2. Train locally (potentially with poisoned data) - 3. Submit weight updates (potentially poisoned) - """ + """Complete training cycle""" self.logger.info(f"╔══ Starting training cycle ══╗") - # STEP 1: Fetch latest global model - self.logger.info(f"[1/3] Fetching latest global model...") if not self.fetch_model(): - self.logger.error(f"Failed to fetch model, aborting cycle") return False - # STEP 2: Train locally - self.logger.info(f"[2/3] Training locally...") loss, accuracy = self.train_locally(epochs=2) - # STEP 3: Submit update (may be poisoned) - self.logger.info(f"[3/3] Submitting update to server...") if not self.submit_update(loss, accuracy): - self.logger.error(f"Failed to submit update") return False self.logger.info(f"ā•šā•ā• Training cycle completed ā•ā•ā•") @@ -315,19 +588,37 @@ def run_training_cycle(self): def wait_for_server_signal(self, timeout=60): """Poll server for training signal""" try: + self.logger.debug(f"Waiting for server signal (timeout={timeout}s)...") response = requests.post( f'{self.server_url}/wait_for_round', json={'client_id': self.client_id}, timeout=timeout ) + self.logger.debug(f"Received response: {response.status_code}") if response.status_code == 200: - return True + data = response.json() + status = data.get('status') + + # Only return True if server explicitly says to train + if status == 'go_train': + self.logger.info(f"āœ“ Signal received from server - GO TRAIN") + return True + elif status == 'already_served_this_round': + # Client already trained this round, wait for next + self.logger.debug("Already served for current round, waiting for next...") + return False + else: + self.logger.warning(f"Unexpected status: {status}") + return False + else: + self.logger.warning(f"Unexpected status code: {response.status_code}") + return False except requests.Timeout: + self.logger.warning(f"ā±ļø Timeout waiting for signal ({timeout}s)") return False except Exception as e: - self.logger.error(f"Error waiting for signal: {e}") + self.logger.error(f"āŒ Error waiting for signal: {e}") return False - return False def main(): client_id = os.getenv('CLIENT_ID', 'malicious_client') @@ -337,35 +628,40 @@ def main(): logger = logging.getLogger(f"CLIENT-{client_id}") - # Wait for server to start - logger.info("Waiting for server to be ready...") + # Wait for server + logger.info("Waiting for server...") for attempt in range(10): try: requests.get(f'{server_url}/health', timeout=2) - logger.info("Server is ready") + logger.info("āœ“ Server ready") break except: - logger.info(f"Waiting for server ({attempt + 1}/10)") + logger.info(f"Waiting ({attempt + 1}/10)") time.sleep(2) - # Initialize client (this will sync with server's initial model) client = MaliciousClient(client_id, server_url, data_file, attack_mode=attack_mode) - # Main loop: wait for signal, train (potentially with attack), repeat logger.info("Entering main loop - waiting for training signals") + consecutive_timeouts = 0 while True: try: + logger.debug(f"Polling for training signal...") if client.wait_for_server_signal(timeout=300): - logger.info(f"šŸ”” Received training signal from server") + logger.info(f"šŸ”” Training signal received - starting training cycle") + consecutive_timeouts = 0 client.run_training_cycle() else: - logger.debug("No training signal received, waiting...") + consecutive_timeouts += 1 + if consecutive_timeouts % 12 == 0: # Log every hour (5s * 12 = 60s) + logger.info(f"Still waiting for training signal... ({consecutive_timeouts} polls)") time.sleep(5) except KeyboardInterrupt: - logger.info("Client shutting down") + logger.info("Shutting down") break except Exception as e: - logger.error(f"Unexpected error: {e}") + logger.error(f"Unexpected error in main loop: {e}") + import traceback + traceback.print_exc() time.sleep(5) if __name__ == '__main__': diff --git a/fl-project/src/server.py b/fl-project/src/server.py index 0260d3a..5b344a5 100644 --- a/fl-project/src/server.py +++ b/fl-project/src/server.py @@ -32,6 +32,7 @@ def __init__(self): self.client_states = {} self.client_updates = {} self.waiting_clients = defaultdict(threading.Event) + self.clients_served_this_round = set() # Track which clients already received signal for current round self.round_in_progress = False self.round_lock = threading.Lock() @@ -39,31 +40,64 @@ def __init__(self): self.audit_log = [] self.detected_attacks = [] - # ADDED: Security - Token & Keys + # Security - Token & Keys self.registration_token = os.getenv('REGISTRATION_TOKEN', 'default_insecure_token') self.client_keys = {} self.server_private_key = ed25519.Ed25519PrivateKey.generate() self.server_public_key = self.server_private_key.public_key() - logger.info(f"Server initialized for {self.max_rounds} rounds") + # DETECTION THRESHOLDS - FIXED TO REDUCE FALSE POSITIVES + self.detection_config = { + # Integrity attack detection - LOOSENED + 'l2_multiplier': 5.0, # Was 2.0 - now allows more deviation + 'cosine_threshold': 0.3, # Was 0.5 - now allows more directional variation + 'mean_ratio_threshold': 3.0, # Was 1.3 - much more permissive + 'std_ratio_threshold': 3.0, # Was 1.5 - much more permissive + + # Sign flip detection + 'negative_cosine_threshold': -0.3, + + # Gaussian noise detection + 'noise_std_multiplier': 10.0, # Was 3.0 - more tolerant + + # Backdoor detection + 'layer_divergence_threshold': 5.0, # Was 2.5 - more tolerant + + # Malicious aggregation - FIXED: This was causing false positives! + 'entropy_threshold': 1.0, # Was 2.5 - MUCH lower threshold (normal updates have ~3-4 entropy) + + # Privacy attack detection + 'gradient_access_limit': 5, + 'confidence_threshold': 0.95, + + # Free riding + 'free_riding_threshold': 0.0001, # More realistic + + # Data poisoning + 'data_poison_loss_threshold': 15.0, # Was 10.0 + 'data_poison_accuracy_threshold': 0.005, # Was 0.01 + } + + # Historical data for detection + self.weight_history = [] + self.client_access_count = defaultdict(int) + + logger.info(f"āœ“ FL Server initialized for {self.max_rounds} rounds") + logger.info(f"Detection enabled: Integrity + Privacy attacks") def _load_or_create_model(self): """Load pre-trained model from file, or create new one if not found""" - # Define paths h5_path = os.getenv('SERVER_MODEL_PATH', './data/global_model.h5') - json_path = './data/global_model_weights.json' # Your JSON source + json_path = './data/global_model_weights.json' - # 1. Try JSON first (Most reliable across different environments) + # Try JSON first if os.path.exists(json_path): try: logger.info(f"Loading weights from JSON: {json_path}") with open(json_path, 'r') as f: data = json.load(f) - # Create the architecture first model = self._create_model() - - # Convert list of lists to numpy arrays weights_np = [np.array(w) for w in data['weights']] model.set_weights(weights_np) @@ -72,45 +106,40 @@ def _load_or_create_model(self): except Exception as e: logger.warning(f"Failed to load from JSON: {e}") - # 2. Try H5 as fallback + # Try H5 as fallback if os.path.exists(h5_path): try: logger.info(f"Attempting to load H5 model: {h5_path}") - # safe_mode=False helps bypass some Keras 3 metadata strictness model = tf.keras.models.load_model(h5_path, safe_mode=False) logger.info("Successfully loaded H5 model") return model except Exception as e: logger.warning(f"H5 load failed: {e}") - # 3. Last resort: Create new + # Create new logger.info("Creating brand new model from scratch") return self._create_model() def _create_model(self): - """Create a new global model with clean Keras 3 architecture""" + """Create a new global model""" model = tf.keras.Sequential([ - # We remove input_length=3 because it's deprecated in Keras 3 - # and was causing those warnings in your logs tf.keras.layers.Embedding(10000, 100), tf.keras.layers.LSTM(150), tf.keras.layers.Dense(10000, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - # Explicitly build with the expected input shape model.build(input_shape=(None, 3)) logger.info("Created architecture: Embedding(10000, 100) → LSTM(150) → Dense(10000)") return model def register_client(self, client_id, num_samples, token=None, public_key_pem=None): - # ADDED: Verify registration token + # Verify registration token if token != self.registration_token: logger.warning(f"šŸ”“ Unauthorized registration attempt from {client_id}") return {'status': 'rejected', 'reason': 'Invalid Registration Token'}, 403 - # ADDED: Store client's public key + # Store client's public key if public_key_pem: try: public_key = serialization.load_pem_public_key(public_key_pem.encode()) @@ -120,6 +149,15 @@ def register_client(self, client_id, num_samples, token=None, public_key_pem=Non logger.error(f"Invalid public key from {client_id}: {e}") return {'status': 'rejected', 'reason': 'Invalid Public Key Format'}, 400 + # If client is re-registering (e.g., container was recreated), + # remove them from the served set so they can participate in current round + if client_id in self.client_states: + logger.info(f"Client {client_id} re-registering (container recreated)") + # Remove from served set to allow participation in current round + self.clients_served_this_round.discard(client_id) + # Clear their event and re-set it so they can receive the signal + self.waiting_clients[client_id].clear() + if client_id not in self.client_states: self.client_states[client_id] = { 'registered_at': datetime.now().isoformat(), @@ -128,16 +166,17 @@ def register_client(self, client_id, num_samples, token=None, public_key_pem=Non 'updates_rejected': 0, 'num_samples': num_samples, 'last_metrics': None, - 'attacks_detected': [] + 'attacks_detected': [], + 'attack_types': defaultdict(int) } logger.info(f"Client registered: {client_id} ({num_samples} samples)") self._log_audit('CLIENT_REGISTERED', {'client_id': client_id, 'num_samples': num_samples}) - # SEND INITIAL MODEL WEIGHTS TO CLIENT + # Send initial model weights weights = [w.tolist() for w in self.global_model.get_weights()] logger.info(f"Sending initial model weights to {client_id} (round {self.round})") - # ADDED: Send server's public key + # Send server's public key server_pub_pem = self.server_public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo @@ -148,11 +187,11 @@ def register_client(self, client_id, num_samples, token=None, public_key_pem=Non 'round': self.round, 'client_id': client_id, 'initial_weights': weights, - 'server_public_key': server_pub_pem # ADDED + 'server_public_key': server_pub_pem }, 200 def verify_signature(self, client_id, payload_bytes, signature_b64): - """ADDED: Verify the digital signature of an update""" + """Verify the digital signature of an update""" if client_id not in self.client_keys: logger.error(f"šŸ”“ Unknown client {client_id} attempted update") return False @@ -167,24 +206,30 @@ def verify_signature(self, client_id, payload_bytes, signature_b64): logger.error(f"šŸ”“ Signature verification failed for {client_id}: {e}") return False + # ========== INTEGRITY ATTACK DETECTION ========== + def detect_poisoning(self, client_id, weights, global_weights): - """Detect poisoning attacks using multiple statistical methods""" + """ + Multi-method poisoning detection: + - L2 norm detection (Byzantine attacks) + - Cosine similarity (direction attacks) + - Statistical divergence (distribution attacks) + """ try: - # Flatten weights for analysis client_flat = np.concatenate([w.flatten() for w in weights]) global_flat = np.concatenate([w.flatten() for w in global_weights]) - # METHOD 1: L2 norm detection (catches large-scale attacks) + # METHOD 1: L2 norm detection l2_norm = np.linalg.norm(client_flat - global_flat) - threshold_l2 = np.linalg.norm(global_flat) * 2.0 + threshold_l2 = np.linalg.norm(global_flat) * self.detection_config['l2_multiplier'] - # METHOD 2: Cosine similarity (catches direction changes) + # METHOD 2: Cosine similarity cosine_sim = np.dot(client_flat, global_flat) / ( np.linalg.norm(client_flat) * np.linalg.norm(global_flat) + 1e-8 ) - threshold_cosine = 0.5 # Low similarity = suspicious + threshold_cosine = self.detection_config['cosine_threshold'] - # METHOD 3: Statistical divergence (catches distribution changes) + # METHOD 3: Statistical divergence mean_val = np.mean(np.abs(client_flat)) std_val = np.std(np.abs(client_flat)) global_mean = np.mean(np.abs(global_flat)) @@ -193,9 +238,8 @@ def detect_poisoning(self, client_id, weights, global_weights): mean_ratio = mean_val / (global_mean + 1e-8) std_ratio = std_val / (global_std + 1e-8) - # Suspicious if mean or std deviates significantly - threshold_mean_ratio = 1.3 # 30% deviation - threshold_std_ratio = 1.5 # 50% deviation + threshold_mean_ratio = self.detection_config['mean_ratio_threshold'] + threshold_std_ratio = self.detection_config['std_ratio_threshold'] # Detect if ANY method triggers detected = False @@ -230,9 +274,7 @@ def detect_poisoning(self, client_id, weights, global_weights): 'cosine_similarity': float(cosine_sim), 'cosine_threshold': float(threshold_cosine), 'mean_ratio': float(mean_ratio), - 'std_ratio': float(std_ratio), - 'mean': float(mean_val), - 'std': float(std_val) + 'std_ratio': float(std_ratio) } return False, 0.0, {} @@ -240,63 +282,519 @@ def detect_poisoning(self, client_id, weights, global_weights): logger.error(f"Error in poisoning detection: {e}") return False, 0.0, {} + def detect_sign_flip(self, client_id, weights, global_weights): + """ + Detect sign flip attacks (gradient ascent instead of descent). + Check if gradients point in opposite direction. + """ + try: + client_flat = np.concatenate([w.flatten() for w in weights]) + global_flat = np.concatenate([w.flatten() for w in global_weights]) + + # Compute cosine similarity + cosine_sim = np.dot(client_flat, global_flat) / ( + np.linalg.norm(client_flat) * np.linalg.norm(global_flat) + 1e-8 + ) + + # Negative cosine indicates opposite direction (sign flip) + if cosine_sim < self.detection_config['negative_cosine_threshold']: + return True, 0.9, { + 'cosine_similarity': float(cosine_sim), + 'detection': 'Gradient points in opposite direction (likely sign flip)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in sign flip detection: {e}") + return False, 0.0, {} + + def detect_gaussian_noise(self, client_id, weights, global_weights): + """ + Detect excessive Gaussian noise injection. + Check if weight variance is unusually high. + """ + try: + client_flat = np.concatenate([w.flatten() for w in weights]) + global_flat = np.concatenate([w.flatten() for w in global_weights]) + + # Compute difference + diff = client_flat - global_flat + + # Check if variance of difference is unusually high + diff_std = np.std(diff) + global_std = np.std(global_flat) + + noise_ratio = diff_std / (global_std + 1e-8) + + if noise_ratio > self.detection_config['noise_std_multiplier']: + return True, 0.85, { + 'noise_ratio': float(noise_ratio), + 'diff_std': float(diff_std), + 'global_std': float(global_std), + 'detection': 'Excessive noise variance detected' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in noise detection: {e}") + return False, 0.0, {} + + def detect_backdoor(self, client_id, weights, global_weights): + """ + Detect backdoor attacks. + Check for unusual layer-specific divergence patterns. + """ + try: + layer_divergences = [] + + for i, (client_w, global_w) in enumerate(zip(weights, global_weights)): + client_arr = np.array(client_w).flatten() + global_arr = np.array(global_w).flatten() + + # Compute layer-specific L2 norm + layer_l2 = np.linalg.norm(client_arr - global_arr) + global_layer_l2 = np.linalg.norm(global_arr) + + divergence_ratio = layer_l2 / (global_layer_l2 + 1e-8) + layer_divergences.append(divergence_ratio) + + # Check if last layer (output layer) has unusually high divergence + # Backdoors often target output layer + if len(layer_divergences) > 0: + last_layer_divergence = layer_divergences[-1] + avg_divergence = np.mean(layer_divergences[:-1]) if len(layer_divergences) > 1 else 0 + + if last_layer_divergence > avg_divergence * self.detection_config['layer_divergence_threshold']: + return True, 0.8, { + 'last_layer_divergence': float(last_layer_divergence), + 'avg_other_layers': float(avg_divergence), + 'layer_divergences': [float(d) for d in layer_divergences], + 'detection': 'Unusual output layer divergence (backdoor pattern)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in backdoor detection: {e}") + return False, 0.0, {} + + def detect_model_replacement(self, client_id, weights, global_weights): + """ + Detect MODEL_REPLACEMENT attacks. + Check if entire model structure differs significantly from global. + """ + try: + # Check if ALL layers diverge significantly + all_divergent = True + divergence_scores = [] + + for client_w, global_w in zip(weights, global_weights): + client_arr = np.array(client_w).flatten() + global_arr = np.array(global_w).flatten() + + # Compute correlation + correlation = np.corrcoef(client_arr, global_arr)[0, 1] + divergence_scores.append(1 - correlation) # Higher = more divergent + + # If all layers are divergent, likely model replacement + avg_divergence = np.mean(divergence_scores) + + if avg_divergence > 0.7: # 70% divergence across all layers + return True, 0.90, { + 'avg_divergence': float(avg_divergence), + 'layer_divergences': [float(d) for d in divergence_scores], + 'detection': 'Entire model structure divergent (model replacement)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in model replacement detection: {e}") + return False, 0.0, {} + + def detect_malicious_aggregation(self, client_id, weights, global_weights): + """ + FIXED: Detect MALICIOUS_AGGREGATION attacks. + Check for carefully crafted updates designed to exploit averaging. + + The key fix: Lower the entropy threshold to 1.0 instead of 2.5. + Normal gradient updates have entropy around 3-4, so only extremely + low entropy (< 1.0) indicates a coordinated attack. + """ + try: + client_flat = np.concatenate([w.flatten() for w in weights]) + global_flat = np.concatenate([w.flatten() for w in global_weights]) + + # Check if update has unusual directionality + diff = client_flat - global_flat + + # Compute entropy of differences (uniform direction = low entropy) + hist, _ = np.histogram(diff, bins=50) + hist = hist / (np.sum(hist) + 1e-8) + entropy = -np.sum(hist * np.log(hist + 1e-8)) + + # FIXED: Very low entropy = coordinated attack + # Normal updates have entropy ~3-4, attacks have < 1.0 + if entropy < self.detection_config['entropy_threshold']: + return True, 0.75, { + 'entropy': float(entropy), + 'detection': f'Extremely low-entropy directional update (entropy={entropy:.2f}, threshold={self.detection_config["entropy_threshold"]})' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in malicious aggregation detection: {e}") + return False, 0.0, {} + + def detect_model_drift(self, client_id, weights): + """ + Detect MODEL_DRIFT attacks. + Track if client consistently deviates in same direction over rounds. + """ + try: + # Need historical data + if len(self.weight_history) < 3: + return False, 0.0, {} + + # Get client's last 3 updates + client_updates = [] + for hist in self.weight_history[-3:]: + if client_id in hist: + client_updates.append(hist[client_id]) + + if len(client_updates) < 3: + return False, 0.0, {} + + # Compute directions of changes + directions = [] + for i in range(len(client_updates) - 1): + curr = np.concatenate([np.array(w).flatten() for w in client_updates[i]]) + next_w = np.concatenate([np.array(w).flatten() for w in client_updates[i+1]]) + direction = next_w - curr + direction = direction / (np.linalg.norm(direction) + 1e-8) + directions.append(direction) + + # Check if directions are consistent (drift) + if len(directions) >= 2: + similarity = np.dot(directions[0], directions[1]) + if similarity > 0.9: # Very consistent direction + return True, 0.70, { + 'direction_similarity': float(similarity), + 'detection': 'Consistent directional drift across rounds' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in drift detection: {e}") + return False, 0.0, {} + + def detect_free_riding(self, client_id, weights, global_weights): + """ + Detect FREE_RIDING attacks. + Check if updates are suspiciously minimal/fake. + """ + try: + client_flat = np.concatenate([w.flatten() for w in weights]) + global_flat = np.concatenate([w.flatten() for w in global_weights]) + + # Compute magnitude of update + update_magnitude = np.linalg.norm(client_flat - global_flat) + global_magnitude = np.linalg.norm(global_flat) + + # Relative update size + relative_update = update_magnitude / (global_magnitude + 1e-8) + + # If update is suspiciously tiny + if relative_update < self.detection_config['free_riding_threshold']: + return True, 0.65, { + 'relative_update_size': float(relative_update), + 'detection': 'Suspiciously minimal update (free-riding)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in free-riding detection: {e}") + return False, 0.0, {} + + def detect_data_poisoning(self, client_id, metrics): + """ + Detect DATA_POISONING attacks. + Check for unusual training metrics that indicate corrupted data. + """ + try: + if 'loss' not in metrics or 'accuracy' not in metrics: + return False, 0.0, {} + + loss = metrics['loss'] + accuracy = metrics['accuracy'] + + # Suspiciously high loss or low accuracy + if loss > self.detection_config['data_poison_loss_threshold'] or accuracy < self.detection_config['data_poison_accuracy_threshold']: + return True, 0.75, { + 'loss': loss, + 'accuracy': accuracy, + 'detection': 'Abnormal training metrics (data poisoning)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in data poisoning detection: {e}") + return False, 0.0, {} + + def detect_adversarial_examples(self, client_id, weights, global_weights): + """ + Detect ADVERSARIAL_EXAMPLES in training. + Similar to noise detection but with specific patterns. + """ + try: + client_flat = np.concatenate([w.flatten() for w in weights]) + global_flat = np.concatenate([w.flatten() for w in global_weights]) + + diff = client_flat - global_flat + + # Adversarial examples create specific high-frequency patterns + # Check for unusual variance patterns + chunk_size = 1000 + chunk_vars = [] + for i in range(0, len(diff), chunk_size): + chunk = diff[i:i+chunk_size] + if len(chunk) > 0: + chunk_vars.append(np.var(chunk)) + + if len(chunk_vars) > 1: + var_of_vars = np.var(chunk_vars) + mean_var = np.mean(chunk_vars) + + # High variance of variances indicates adversarial patterns + if var_of_vars / (mean_var + 1e-8) > 5.0: + return True, 0.70, { + 'variance_ratio': float(var_of_vars / (mean_var + 1e-8)), + 'detection': 'High-frequency variance patterns (adversarial examples)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in adversarial examples detection: {e}") + return False, 0.0, {} + + # ========== PRIVACY ATTACK DETECTION ========== + + def detect_gradient_inversion(self, client_id): + """ + Detect GRADIENT_INVERSION attacks. + Track if client excessively accesses gradients. + """ + try: + self.client_access_count[client_id] += 1 + + if self.client_access_count[client_id] > self.detection_config['gradient_access_limit']: + return True, 0.70, { + 'access_count': self.client_access_count[client_id], + 'limit': self.detection_config['gradient_access_limit'], + 'detection': 'Excessive gradient access (potential inversion attack)' + } + + return False, 0.0, {} + except Exception as e: + logger.error(f"Error in gradient inversion detection: {e}") + return False, 0.0, {} + + def detect_membership_inference(self, client_id, metrics): + """ + Detect MEMBERSHIP_INFERENCE attacks. + High confidence predictions can indicate membership inference. + """ + # This is passive - would require analyzing prediction patterns + # For now, just a placeholder + return False, 0.0, {} + + def detect_property_inference(self, client_id): + """ + Detect PROPERTY_INFERENCE attacks. + Analyzing weight statistics can reveal dataset properties. + """ + # This is also passive - placeholder + return False, 0.0, {} + + # ========== UPDATE PROCESSING ========== + def process_update(self, client_id, weights, metrics): - """Process and analyze client update""" - logger.info(f"Processing update from {client_id} (Round {self.round})") + """ + Process client update with comprehensive attack detection. + Returns (accepted, reason). + """ + if client_id not in self.client_states: + logger.error(f"Update from unregistered client: {client_id}") + return False, "Client not registered" + + logger.info(f"Processing update from {client_id}") + # Update counter self.client_states[client_id]['updates_received'] += 1 self.client_states[client_id]['last_metrics'] = metrics - # Security analysis + # Get global weights for comparison global_weights = self.global_model.get_weights() - is_poisoned, confidence, detection_details = self.detect_poisoning(client_id, weights, global_weights) - # Log the analysis - analysis = { - 'round': self.round, - 'client_id': client_id, - 'timestamp': datetime.now().isoformat(), - 'is_poisoned': is_poisoned, - 'confidence': float(confidence), - 'detection_details': detection_details, - 'metrics': metrics - } - self._log_audit('UPDATE_ANALYZED', analysis) + # Run all detection methods + attacks_detected = [] - if is_poisoned: - logger.warning(f"ATTACK DETECTED: {client_id} in round {self.round} (confidence={confidence:.2f})") - self.detected_attacks.append(analysis) - self.client_states[client_id]['attacks_detected'].append({ - 'round': self.round, - 'confidence': confidence, - 'details': detection_details + # INTEGRITY ATTACKS + detected, conf, details = self.detect_poisoning(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'POISONING', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_sign_flip(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'SIGN_FLIP', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_gaussian_noise(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'GAUSSIAN_NOISE', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_backdoor(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'BACKDOOR', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_model_replacement(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'MODEL_REPLACEMENT', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_malicious_aggregation(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'MALICIOUS_AGGREGATION', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_model_drift(client_id, weights) + if detected: + attacks_detected.append({ + 'type': 'MODEL_DRIFT', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_free_riding(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'FREE_RIDING', + 'confidence': conf, + 'details': details + }) + + detected, conf, details = self.detect_data_poisoning(client_id, metrics) + if detected: + attacks_detected.append({ + 'type': 'DATA_POISONING', + 'confidence': conf, + 'details': details }) + + detected, conf, details = self.detect_adversarial_examples(client_id, weights, global_weights) + if detected: + attacks_detected.append({ + 'type': 'ADVERSARIAL_EXAMPLES', + 'confidence': conf, + 'details': details + }) + + # PRIVACY ATTACKS + detected, conf, details = self.detect_gradient_inversion(client_id) + if detected: + attacks_detected.append({ + 'type': 'GRADIENT_INVERSION', + 'confidence': conf, + 'details': details + }) + + # Log and handle attacks + if len(attacks_detected) > 0: + logger.warning(f"šŸ”“ ATTACK DETECTED from {client_id}: {len(attacks_detected)} attack(s)") + for attack in attacks_detected: + logger.warning(f" └─ {attack['type']} (confidence={attack['confidence']:.2f})") + self.client_states[client_id]['attack_types'][attack['type']] += 1 + + # Store attack in history + attack_entry = { + 'round': self.round, + 'client_id': client_id, + 'attacks_detected': attacks_detected, + 'metrics': metrics, + 'timestamp': datetime.now().isoformat() + } + self.detected_attacks.append(attack_entry) + self.client_states[client_id]['attacks_detected'].append(attack_entry) + + # REJECT the update self.client_states[client_id]['updates_rejected'] += 1 - return False, "POISONING_DETECTED" + self._log_audit('UPDATE_REJECTED', { + 'client_id': client_id, + 'round': self.round, + 'attacks': [a['type'] for a in attacks_detected] + }) + + return False, f"Attack detected: {', '.join([a['type'] for a in attacks_detected])}" - # Store valid update - self.client_updates[client_id] = { - 'weights': weights, - 'metrics': metrics, - 'timestamp': datetime.now() - } + # ACCEPT update + logger.info(f"āœ“ Update from {client_id} ACCEPTED") + self.client_updates[client_id] = weights self.client_states[client_id]['updates_accepted'] += 1 - logger.info(f"Update accepted from {client_id}") - return True, "ACCEPTED" + + # Store weights for historical analysis + if len(self.weight_history) == 0 or self.round not in [h.get('round') for h in self.weight_history]: + self.weight_history.append({'round': self.round}) + + for hist in self.weight_history: + if hist.get('round') == self.round: + hist[client_id] = weights + break + + self._log_audit('UPDATE_ACCEPTED', { + 'client_id': client_id, + 'round': self.round, + 'metrics': metrics + }) + + # Note: Aggregation is triggered manually via control.py in the test workflow + # Auto-aggregation is disabled to allow manual control of training rounds + logger.info(f"āœ“ Update accepted from {client_id} ({len(self.client_updates)}/{len(self.client_states)} received)") + + return True, "Accepted" - def aggregate_model(self): - """Aggregate accepted updates using FedAvg""" - if not self.client_updates: + def aggregate_updates(self): + """Aggregate client updates using FedAvg""" + if len(self.client_updates) == 0: logger.warning("No updates to aggregate") return False - logger.info(f"Starting aggregation with {len(self.client_updates)} clients") + logger.info(f"Aggregating {len(self.client_updates)} updates") + # FedAvg: Weighted average by num_samples + num_layers = len(self.global_model.get_weights()) new_weights = [np.zeros_like(w) for w in self.global_model.get_weights()] total_samples = 0 - for client_id, data in self.client_updates.items(): - client_weights = data['weights'] + for client_id, client_weights in self.client_updates.items(): num_samples = self.client_states[client_id]['num_samples'] for i, w in enumerate(client_weights): @@ -316,6 +814,13 @@ def aggregate_model(self): # Clear for next round self.client_updates.clear() + + # Clear the tracking of which clients have been served + self.clients_served_this_round.clear() + + # Reset client signals so they can wait for the next round + self.reset_client_signals() + self.round += 1 return True @@ -332,6 +837,36 @@ def reset_client_signals(self): for client_id in self.client_states: self.waiting_clients[client_id].clear() + def reset_round_state(self): + """ + ADDED: Reset server state for current round. + This is needed between attack scenarios in testing. + """ + logger.info("šŸ”„ Resetting round state (keeping client registrations)") + + # Clear pending updates + self.client_updates.clear() + + # Clear the served clients tracking for new round + self.clients_served_this_round.clear() + + # Reset signals + self.reset_client_signals() + + # Clear attacks for THIS round only (keep history) + self.detected_attacks = [a for a in self.detected_attacks if a['round'] != self.round] + + # Reset per-client round stats (but keep overall stats) + for client_id in self.client_states: + # Don't reset updates_received, updates_accepted, updates_rejected + # Just clear the current round's attack detections + self.client_states[client_id]['attacks_detected'] = [ + a for a in self.client_states[client_id]['attacks_detected'] + if a['round'] != self.round + ] + + logger.info("āœ“ Round state reset complete") + def _log_audit(self, event_type, data): """Log event to audit trail""" audit_entry = { @@ -344,7 +879,7 @@ def _log_audit(self, event_type, data): def save_round_report(self): """Save detailed report of current round""" report = { - 'round': self.round - 1, # Previous round (now completed) + 'round': self.round - 1, 'timestamp': datetime.now().isoformat(), 'clients': dict(self.client_states), 'attacks_detected': self.detected_attacks, @@ -382,7 +917,6 @@ def init_client(): client_id = data.get('client_id') num_samples = data.get('num_samples', 0) - # MODIFIED: Include token and public key response, status_code = server.register_client( client_id, num_samples, @@ -393,7 +927,7 @@ def init_client(): @app.route('/get_model', methods=['POST']) def get_model(): - """MODIFIED: Send SIGNED global model weights to client""" + """Send SIGNED global model weights to client""" try: weights = [w.tolist() for w in server.global_model.get_weights()] payload_content = { @@ -401,7 +935,7 @@ def get_model(): 'round': server.round } - # ADDED: Sign the payload + # Sign the payload payload_bytes = json.dumps(payload_content, sort_keys=True).encode() signature = server.server_private_key.sign(payload_bytes) signature_b64 = base64.b64encode(signature).decode('utf-8') @@ -416,11 +950,11 @@ def get_model(): @app.route('/submit_update', methods=['POST']) def submit_update(): - """MODIFIED: Receive SIGNED model update from client""" + """Receive SIGNED model update from client with attack detection""" try: data = request.json - # ADDED: Handle signed payload + # Handle signed payload if 'payload' in data and 'signature' in data: payload_content = data['payload'] signature = data['signature'] @@ -437,7 +971,7 @@ def submit_update(): weights = [np.array(w) for w in payload_content.get('weights', [])] metrics = payload_content.get('metrics', {}) else: - # Backward compatibility: no signature + # Backward compatibility client_id = data.get('client_id') weights = [np.array(w) for w in data.get('weights', [])] metrics = data.get('metrics', {}) @@ -461,26 +995,62 @@ def wait_for_round(): data = request.json client_id = data.get('client_id') - # Get or create event for this client + # Ensure the client is registered + if client_id not in server.client_states: + logger.warning(f"Client {client_id} not registered, rejecting wait request") + return jsonify({'status': 'not_registered'}), 403 + + # Check if this client has already been served for the current round + if client_id in server.clients_served_this_round: + logger.debug(f"Client {client_id} already served for round {server.round}, waiting for next round...") + # Client should wait - don't return signal yet + # This prevents the tight loop problem + return jsonify({'status': 'already_served_this_round'}), 200 + + # Get or create the event for this client event = server.waiting_clients[client_id] - # Wait for signal (with timeout) + logger.debug(f"Client {client_id} waiting for round signal...") signaled = event.wait(timeout=300) if signaled: - # Reset the event for next round + # Mark this client as served for this round + server.clients_served_this_round.add(client_id) + + # Clear the event for this specific client so they don't receive signal again event.clear() + + logger.info(f"āœ“ Client {client_id} received training signal for round {server.round}") return jsonify({'status': 'go_train', 'round': server.round}) else: + logger.warning(f"ā±ļø Client {client_id} timeout waiting for signal") return jsonify({'status': 'timeout'}), 408 @app.route('/trigger_round', methods=['POST']) def trigger_round(): - """Trigger a training round - signal all waiting clients""" - logger.info(f"Triggering training round {server.round}") + """Trigger a training round""" + logger.info(f"šŸ”” TRIGGERING training round {server.round}") + logger.info(f"šŸ“Š Registered clients: {list(server.client_states.keys())}") server.signal_clients_for_round() + logger.info(f"āœ“ All {len(server.client_states)} clients signaled to start training") return jsonify({'status': 'triggered', 'round': server.round}), 200 +@app.route('/trigger_aggregation', methods=['POST']) +def trigger_aggregation(): + """Manually trigger model aggregation""" + logger.info("šŸ“Š Manual aggregation triggered") + if server.aggregate_updates(): + return jsonify({'status': 'aggregated', 'round': server.round}), 200 + else: + return jsonify({'status': 'no_updates', 'round': server.round}), 200 + +@app.route('/reset_round', methods=['POST']) +def reset_round(): + """ADDED: Reset round state (for testing between attack scenarios)""" + logger.info("Received request to reset round state") + server.reset_round_state() + return jsonify({'status': 'reset', 'round': server.round}), 200 + @app.route('/status', methods=['GET']) def status(): """Get detailed server status""" @@ -495,12 +1065,18 @@ def status(): @app.route('/security/status', methods=['GET']) def security_status(): """Security dashboard""" + attack_summary = defaultdict(int) + for attack in server.detected_attacks: + for detected in attack['attacks_detected']: + attack_summary[detected['type']] += 1 + return jsonify({ 'monitoring': 'ACTIVE', 'total_attacks_detected': len(server.detected_attacks), - 'attacks': server.detected_attacks + 'attack_types_summary': dict(attack_summary), + 'recent_attacks': server.detected_attacks[-10:] if len(server.detected_attacks) > 10 else server.detected_attacks }) if __name__ == '__main__': - logger.info("Starting FL Server") + logger.info("Starting FL Server with Multi-Attack Detection") app.run(host='0.0.0.0', port=5000, debug=False, threaded=True) \ No newline at end of file diff --git a/fl-project/src/utils/control.py b/fl-project/src/utils/control.py index acad1c3..55f10b5 100644 --- a/fl-project/src/utils/control.py +++ b/fl-project/src/utils/control.py @@ -63,31 +63,68 @@ def print_status(self): metrics = state['last_metrics'] print(f" Metrics: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}") if state['attacks_detected']: - print(f" ATTACKS DETECTED: {len(state['attacks_detected'])}") + print(f" āš ļø ATTACKS DETECTED: {len(state['attacks_detected'])}") + # Show attack types if available + if 'attack_types' in state: + attack_summary = [] + for attack_type, count in state['attack_types'].items(): + if count > 0: + attack_summary.append(f"{attack_type}({count})") + if attack_summary: + print(f" Attack Types: {', '.join(attack_summary)}") print(f"\nPending updates: {status['pending_updates']}") print(f"Attacks detected this round: {status['attacks_detected_this_round']}") print(f"Total attacks detected: {status['total_attacks_detected']}") if status['attacks_detected_this_round'] > 0: - sec_status = requests.get(f'{self.server_url}/security/status').json() - print(f"\nDetected attacks:") - for attack in sec_status['attacks']: - if attack['round'] == status['round']: - print(f" Round {attack['round']}: {attack['client_id']} (confidence={attack['confidence']:.2f})") + try: + sec_status = requests.get(f'{self.server_url}/security/status').json() + print(f"\nšŸ”“ Recent Attack Details:") + + # Show attack type summary if available + if 'attack_types_summary' in sec_status: + print(f"\nAttack Types Summary:") + for attack_type, count in sec_status['attack_types_summary'].items(): + print(f" - {attack_type}: {count} times") + + # Show recent attacks + for attack in sec_status.get('recent_attacks', []): + if attack['round'] == status['round']: + client_id = attack.get('client_id', 'unknown') + confidence = attack.get('overall_confidence', 0) + print(f"\n Round {attack['round']}: {client_id} (confidence={confidence:.2f})") + + # Show detailed attack types + if 'attacks_detected' in attack: + for det in attack['attacks_detected']: + attack_type = det.get('type', 'UNKNOWN') + det_conf = det.get('confidence', 0) + print(f" └─ {attack_type} (confidence={det_conf:.2f})") + + # Show detection details if available + if 'details' in det and det['details']: + details = det['details'] + if 'detection_methods' in details: + print(f" Methods: {', '.join(details['detection_methods'])}") + if 'detection' in details: + print(f" {details['detection']}") + except Exception as e: + logger.error(f"Error getting security status: {e}") def signal_training_round(self): - """Tell server to signal all clients to begin training""" - # Note: In this simplified version, the server signals clients - # In a real implementation, you might use a dedicated endpoint - logger.info(f"Signaling clients to begin training") + """Signal server to trigger training round""" try: - # Make a request to the server to signal clients - # This is handled by the /wait_for_round endpoint with threading - return True + response = requests.post( + f'{self.server_url}/trigger_round', + timeout=5 + ) + if response.status_code == 200: + logger.info("āœ“ Training round triggered at server") + return True except Exception as e: - logger.error(f"Error signaling clients: {e}") - return False + logger.error(f"Error triggering round: {e}") + return False def wait_for_updates(self, timeout=60): """Wait for clients to submit updates""" @@ -107,10 +144,38 @@ def wait_for_updates(self, timeout=60): def trigger_aggregation(self): """Trigger model aggregation at the server""" - # In this simplified version, aggregation is triggered manually - # You would need to add a /trigger_round endpoint if needed - logger.info("Aggregation complete (clients' updates have been processed)") - return True + try: + response = requests.post( + f'{self.server_url}/trigger_aggregation', + timeout=5 + ) + if response.status_code == 200: + data = response.json() + logger.info(f"Aggregation triggered - now at round {data['round']}") + return True + else: + logger.error(f"Failed to trigger aggregation: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error triggering aggregation: {e}") + return False + + def reset_round(self): + """Reset the server's round state (for testing between attack scenarios)""" + try: + response = requests.post( + f'{self.server_url}/reset_round', + timeout=5 + ) + if response.status_code == 200: + logger.info("āœ“ Server round state reset") + return True + else: + logger.error(f"Failed to reset round: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error resetting round: {e}") + return False def interactive_monitor(self, interval=10): """Continuous monitoring mode""" @@ -122,86 +187,105 @@ def interactive_monitor(self, interval=10): time.sleep(interval) except KeyboardInterrupt: logger.info("Monitoring stopped") - - def signal_training_round(self): - """Signal server to trigger training round""" + + def trigger_round(self): + """Signal the server to start a new training round""" try: - response = requests.post( - f'{self.server_url}/trigger_round', - timeout=5 - ) + # Assuming the endpoint is /start_round based on your logs + response = requests.post(f'{self.server_url}/trigger_round') if response.status_code == 200: logger.info("Training round triggered at server") return True + else: + logger.error(f"Failed to start round: {response.text}") + return False except Exception as e: logger.error(f"Error triggering round: {e}") - return False - + return False + + def trigger_aggregation(self): + """Signal the server to aggregate updates""" + try: + response = requests.post(f'{self.server_url}/trigger_aggregation') + if response.status_code == 200: + logger.info("Aggregation triggered") + return True + else: + logger.error(f"Failed to aggregate: {response.text}") + return False + except Exception as e: + logger.error(f"Error triggering aggregation: {e}") + return False + def run_training_sequence(self, num_rounds=5, wait_time=60): - """ - Full training sequence: - 1. Signal server to trigger round (which signals waiting clients) - 2. Wait for updates - 3. Repeat - """ + """Run a sequence of training rounds with smart polling""" logger.info(f"Starting training sequence for {num_rounds} rounds") - - for round_num in range(1, num_rounds + 1): - logger.info(f"ROUND {round_num}/{num_rounds}") - print("="*70) - - # Get status before round - status = self.get_status() - if not status: - logger.error("Could not get status") - continue - - current_round = status['round'] - logger.info(f"Triggering round {current_round} - signaling clients to train") - - # Signal server to trigger training - self.signal_training_round() - - # Wait for updates - logger.info(f"Waiting {wait_time}s for client updates...") - time.sleep(wait_time) - - # Check status - status = self.get_status() - if not status: - logger.error("Could not get status") - continue - - # Print round summary - total_received = sum(c['updates_received'] for c in status['clients'].values()) - total_accepted = sum(c['updates_accepted'] for c in status['clients'].values()) - total_rejected = sum(c['updates_rejected'] for c in status['clients'].values()) - - logger.info(f"Round {current_round} Summary:") - logger.info(f" - Updates received: {total_received}") - logger.info(f" - Updates accepted: {total_accepted}") - logger.info(f" - Updates rejected: {total_rejected}") - - if status['attacks_detected_this_round'] > 0: - logger.warning(f" - ATTACKS DETECTED: {status['attacks_detected_this_round']}") - - # Print which clients were detected - for client_id, state in status['clients'].items(): - if state['attacks_detected']: - for attack in state['attacks_detected']: - if attack['round'] == current_round: - logger.warning(f" {client_id} (confidence={attack['confidence']:.2f})") - - logger.info(f"Round {current_round} completed") + + for i in range(num_rounds): + logger.info("="*70) + logger.info(f"šŸ”„ ROUND {i+1}/{num_rounds}") + + # 1. Trigger the round + if not self.trigger_round(): + logger.error("Failed to trigger round - stopping sequence") + break + + # 2. Smart Wait Loop (Optimization) + # We poll the server status to see if all clients have replied + logger.info(f"ā³ Waiting for updates (polling every 2s, max {wait_time}s)...") + start_time = time.time() + + while (time.time() - start_time) < wait_time: + status = self.get_status() + if status: + # In server.py, 'pending_updates' is list(server.client_updates.keys()) + # This represents valid updates sitting in the buffer waiting for aggregation. + updates_received = len(status.get('pending_updates', [])) + total_clients = len(status.get('clients', {})) + + # If we have updates from all registered clients, stop waiting! + if total_clients > 0 and updates_received >= total_clients: + logger.info(f"āœ“ All clients reported ({updates_received}/{total_clients}). Proceeding immediately.") + break + + # Poll interval + time.sleep(2) + + # 3. Trigger Aggregation + logger.info("šŸ“Š Triggering model aggregation...") + self.trigger_aggregation() + + # 4. Print Summary + self.print_status() + + # Short buffer before next round to ensure logs are clean time.sleep(2) - + + logger.info("="*70) logger.info(f"Training sequence completed {num_rounds} rounds") - self.print_status() + logger.info("="*70) def main(): - parser = argparse.ArgumentParser(description='FL Control & Orchestration') + parser = argparse.ArgumentParser( + description='FL Control & Orchestration', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Show current status + python control.py --mode status + + # Monitor continuously + python control.py --mode monitor --interval 10 + + # Reset round state (for testing between attack scenarios) + python control.py --mode reset + + # Run 5 training rounds + python control.py --mode train --rounds 5 --wait 100 + """ + ) parser.add_argument('--mode', default='monitor', - choices=['status', 'monitor', 'train'], + choices=['status', 'monitor', 'train', 'reset'], help='Operation mode') parser.add_argument('--rounds', type=int, default=5, help='Number of rounds for training mode') @@ -214,15 +298,26 @@ def main(): controller = FLController() + # Check server health + print("\n" + "="*70) + print("FL CONTROLLER - Starting") + print("="*70) + if not controller.check_server_health(): + logger.error("Server is not available. Please start the server first.") sys.exit(1) + # Execute requested mode if args.mode == 'status': controller.print_status() elif args.mode == 'monitor': controller.interactive_monitor(interval=args.interval) + elif args.mode == 'reset': + controller.reset_round() + controller.print_status() + elif args.mode == 'train': controller.run_training_sequence(num_rounds=args.rounds, wait_time=args.wait) diff --git a/fl-project/src/utils/metrics.py b/fl-project/src/utils/metrics.py new file mode 100644 index 0000000..080edfd --- /dev/null +++ b/fl-project/src/utils/metrics.py @@ -0,0 +1,156 @@ +""" +Prometheus metrics exporter for FL Server +Exposes metrics that Grafana dashboards can query +""" + +from prometheus_client import Counter, Gauge, Histogram, generate_latest, REGISTRY +from flask import Response +import logging + +logger = logging.getLogger("METRICS") + +# ============================================================================ +# TRAINING METRICS +# ============================================================================ + +# Current training round +fl_training_round = Gauge( + 'fl_training_round', + 'Current federated learning training round' +) + +# Update counters +fl_updates_received = Counter( + 'fl_updates_received', + 'Total number of updates received from clients', + ['client_id'] +) + +fl_updates_accepted = Counter( + 'fl_updates_accepted', + 'Total number of updates accepted', + ['client_id'] +) + +fl_updates_rejected = Counter( + 'fl_updates_rejected', + 'Total number of updates rejected', + ['client_id'] +) + +# ============================================================================ +# SECURITY METRICS +# ============================================================================ + +# Attack detection counters +fl_attacks_detected = Counter( + 'fl_attacks_detected', + 'Total number of attacks detected', + ['client_id', 'attack_type'] +) + +fl_attacks_high_confidence = Counter( + 'fl_attacks_high_confidence', + 'Number of high-confidence attacks (>90%)', + ['client_id', 'attack_type'] +) + +# Current threat level (0-100) +fl_threat_level = Gauge( + 'fl_threat_level', + 'Current overall threat level (0-100)' +) + +# ============================================================================ +# CLIENT METRICS +# ============================================================================ + +fl_clients_registered = Gauge( + 'fl_clients_registered', + 'Number of registered clients' +) + +fl_client_last_update = Gauge( + 'fl_client_last_update', + 'Timestamp of last update from client', + ['client_id'] +) + +# ============================================================================ +# MODEL PERFORMANCE METRICS +# ============================================================================ + +fl_model_loss = Gauge( + 'fl_model_loss', + 'Training loss reported by client', + ['client_id'] +) + +fl_model_accuracy = Gauge( + 'fl_model_accuracy', + 'Training accuracy reported by client', + ['client_id'] +) + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + +def record_update_received(client_id): + """Record that an update was received from a client""" + fl_updates_received.labels(client_id=client_id).inc() + logger.debug(f"Metrics: Update received from {client_id}") + +def record_update_accepted(client_id): + """Record that an update was accepted""" + fl_updates_accepted.labels(client_id=client_id).inc() + logger.debug(f"Metrics: Update accepted from {client_id}") + +def record_update_rejected(client_id): + """Record that an update was rejected""" + fl_updates_rejected.labels(client_id=client_id).inc() + logger.debug(f"Metrics: Update rejected from {client_id}") + +def record_attack_detected(client_id, attack_type, confidence): + """Record an attack detection""" + fl_attacks_detected.labels(client_id=client_id, attack_type=attack_type).inc() + + if confidence > 0.9: + fl_attacks_high_confidence.labels(client_id=client_id, attack_type=attack_type).inc() + + logger.debug(f"Metrics: Attack detected - {client_id}/{attack_type} (conf={confidence:.2f})") + +def update_round(round_number): + """Update the current round number""" + fl_training_round.set(round_number) + logger.debug(f"Metrics: Round updated to {round_number}") + +def update_client_count(count): + """Update the number of registered clients""" + fl_clients_registered.set(count) + logger.debug(f"Metrics: Client count updated to {count}") + +def update_client_metrics(client_id, loss, accuracy): + """Update client training metrics""" + if loss is not None: + fl_model_loss.labels(client_id=client_id).set(loss) + if accuracy is not None: + fl_model_accuracy.labels(client_id=client_id).set(accuracy) + logger.debug(f"Metrics: Client {client_id} - loss={loss}, acc={accuracy}") + +def update_threat_level(level): + """Update overall threat level (0-100)""" + fl_threat_level.set(level) + logger.debug(f"Metrics: Threat level updated to {level}") + +def get_metrics(): + """Return Prometheus metrics in text format""" + return generate_latest(REGISTRY) + +def create_metrics_endpoint(app): + """Create /metrics endpoint for Prometheus to scrape""" + @app.route('/metrics') + def metrics(): + return Response(get_metrics(), mimetype='text/plain') + + logger.info("āœ“ Prometheus metrics endpoint created at /metrics") \ No newline at end of file diff --git a/fl-project/src/utils/structured_logging.py b/fl-project/src/utils/structured_logging.py new file mode 100644 index 0000000..c8ef907 --- /dev/null +++ b/fl-project/src/utils/structured_logging.py @@ -0,0 +1,155 @@ +""" +Structured logging for Loki integration +Emits JSON logs that Grafana/Loki can parse and query +""" + +import json +import logging +from datetime import datetime + +class StructuredLogger: + """ + Logger that emits structured JSON logs for Loki + These logs can be queried in Grafana dashboards + """ + + def __init__(self, name="FL-STRUCTURED"): + self.logger = logging.getLogger(name) + + def _emit(self, event_type, data): + """Emit a structured log entry""" + log_entry = { + 'timestamp': datetime.now().isoformat(), + 'event_type': event_type, + **data + } + # Log as JSON string so Loki can parse it + self.logger.info(json.dumps(log_entry)) + + # ======================================================================== + # TRAINING EVENTS + # ======================================================================== + + def round_started(self, round_num, num_clients): + """Log that a training round has started""" + self._emit('ROUND_START', { + 'round': round_num, + 'num_clients': num_clients, + 'message': f'Round {round_num} started with {num_clients} clients' + }) + + def round_completed(self, round_num, updates_accepted, updates_rejected): + """Log that a training round has completed""" + self._emit('ROUND_END', { + 'round': round_num, + 'updates_accepted': updates_accepted, + 'updates_rejected': updates_rejected, + 'message': f'Round {round_num} completed: {updates_accepted} accepted, {updates_rejected} rejected' + }) + + def aggregation_completed(self, round_num, num_updates): + """Log that model aggregation is complete""" + self._emit('AGGREGATION_COMPLETED', { + 'round': round_num, + 'num_updates': num_updates, + 'message': f'Aggregated {num_updates} updates for round {round_num}' + }) + + # ======================================================================== + # SECURITY EVENTS + # ======================================================================== + + def attack_detected(self, round_num, client_id, attack_type, confidence, details=None): + """Log an attack detection""" + self._emit('ATTACK_DETECTED', { + 'round': round_num, + 'client_id': client_id, + 'attack_type': attack_type, + 'confidence': confidence, + 'severity': 'HIGH' if confidence > 0.9 else 'MEDIUM' if confidence > 0.7 else 'LOW', + 'details': details or {}, + 'message': f'Attack detected: {client_id} - {attack_type} (confidence={confidence:.2f})' + }) + + def attack_rejected(self, round_num, client_id, attack_types): + """Log that an update was rejected due to detected attacks""" + self._emit('ATTACK_REJECTED', { + 'round': round_num, + 'client_id': client_id, + 'attack_types': attack_types, + 'message': f'Update rejected from {client_id} due to: {", ".join(attack_types)}' + }) + + def security_alert(self, alert_type, client_id, message, severity='MEDIUM'): + """Log a general security alert""" + self._emit('SECURITY_ALERT', { + 'alert_type': alert_type, + 'client_id': client_id, + 'severity': severity, + 'message': message + }) + + # ======================================================================== + # CLIENT EVENTS + # ======================================================================== + + def client_registered(self, client_id, num_samples): + """Log client registration""" + self._emit('CLIENT_REGISTERED', { + 'client_id': client_id, + 'num_samples': num_samples, + 'message': f'Client {client_id} registered with {num_samples} samples' + }) + + def client_update_received(self, round_num, client_id, loss, accuracy): + """Log that an update was received from a client""" + self._emit('CLIENT_UPDATE_RECEIVED', { + 'round': round_num, + 'client_id': client_id, + 'loss': loss, + 'accuracy': accuracy, + 'message': f'Update received from {client_id}: loss={loss:.4f}, acc={accuracy:.4f}' + }) + + def client_update_accepted(self, round_num, client_id): + """Log that a client's update was accepted""" + self._emit('CLIENT_UPDATE_ACCEPTED', { + 'round': round_num, + 'client_id': client_id, + 'message': f'Update accepted from {client_id}' + }) + + def client_update_rejected(self, round_num, client_id, reason): + """Log that a client's update was rejected""" + self._emit('CLIENT_UPDATE_REJECTED', { + 'round': round_num, + 'client_id': client_id, + 'reason': reason, + 'message': f'Update rejected from {client_id}: {reason}' + }) + + # ======================================================================== + # MODEL EVENTS + # ======================================================================== + + def model_performance(self, round_num, global_metrics): + """Log global model performance metrics""" + self._emit('MODEL_PERFORMANCE', { + 'round': round_num, + 'metrics': global_metrics, + 'message': f'Global model performance at round {round_num}' + }) + + # ======================================================================== + # SYSTEM EVENTS + # ======================================================================== + + def system_event(self, event_type, message, data=None): + """Log a general system event""" + self._emit(event_type, { + 'message': message, + **(data or {}) + }) + +# Global structured logger instance +structured_logger = StructuredLogger() \ No newline at end of file