aboutsummaryrefslogtreecommitdiffstats
path: root/src/fsec-optimize/optimizer.c
blob: f5bef33e28f8ca3124dd45eb622acc99f907ebd6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*
 * Copyright (C) 2014-2018 Firejail Authors
 *
 * This file is part of firejail project
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
#include "fsec_optimize.h"

// From /usr/include/linux/filter.h
//struct sock_filter {	/* Filter block */
//	__u16	code;   /* Actual filter code */
//	__u8	jt;	/* Jump true */
//	__u8	jf;	/* Jump false */
//	__u32	k;      /* Generic multiuse field */
//};


#define LIMIT_BLACKLISTS 4	// blacklists are optimized only if the filter has more than this many "compare + ret KILL" pairs

// Tell whether bpf and the instruction right after it form a blacklist pair:
// a "jump if equal to constant" followed by "return SECCOMP_RET_KILL".
// Reads bpf[0] and, only when bpf[0] is such a compare, bpf[1].
// Returns 1 for a pair, 0 otherwise.
static inline int is_blacklist(struct sock_filter *bpf) {
	if (bpf->code != BPF_JMP + BPF_JEQ + BPF_K)
		return 0;

	struct sock_filter *next = bpf + 1;
	if (next->code != BPF_RET + BPF_K)
		return 0;

	return next->k == SECCOMP_RET_KILL;
}

// Count the positions in filter (holding 'entries' instructions) at which
// is_blacklist() reports a blacklist pair.
static int count_blacklists(struct sock_filter *filter, int entries) {
	int found = 0;
	int idx = 0;
	int last = entries - 1;	// is_blacklist looks at [idx] and [idx + 1], so stop one short

	while (idx < last) {
		if (is_blacklist(&filter[idx]))
			found++;
		idx++;
	}

	return found;
}

// Per-instruction bookkeeping used by optimize_blacklists()
typedef struct {
	int to_remove;	// set on the "ret KILL" half of a blacklist pair; such lines are dropped
	int to_fix_jumps;	// set on the compare half of a blacklist pair; its jt/jf are rewritten
} Action;

// Make all blacklist pairs share a single trailing "ret KILL".
//
// Every "compare; ret KILL" pair recognised by is_blacklist() loses its own
// "ret KILL". One "ret KILL" is appended as the new last instruction, and each
// remaining compare gets jt pointing at it and jf = 0, so a mismatch falls
// through to the next instruction. All other instructions, including their
// jump offsets, are copied unchanged.
//
// filter  - program to rewrite in place, holding 'entries' instructions
// entries - number of instructions in filter
//
// Returns the new number of instructions. The program is left untouched and
// 'entries' is returned when no pair is found, or when a rewritten jump would
// not fit in the 8-bit jt field of struct sock_filter.
static int optimize_blacklists(struct sock_filter *filter, int entries) {
	assert(entries);
	assert(filter);
	enum { MAX_JUMP = 255 };	// jt and jf are 8 bits wide
	int i;
	int j;

	// step 1: extract information
	Action action[entries];
	memset(&action[0], 0, sizeof(Action) * entries);
	int remove_cnt = 0;
	for (i = 0; i < (entries - 1); i++) { // is_blacklist works on two consecutive lines; using entries - 1
		if (is_blacklist(filter + i)) {
			action[i].to_fix_jumps = 1;
			i++;	// step over the ret KILL so it is not examined as a compare
			action[i].to_remove = 1;
			remove_cnt++;
		}
	}

	// without at least one removed line there is no free slot for the
	// shared ret KILL appended in step 3
	if (remove_cnt == 0)
		return entries;

	// step 2: remove lines, working on a scratch copy so that filter stays
	// intact until the result is known to be valid
	struct sock_filter *filter_step2 = duplicate(filter, entries);
	Action action_step2[entries];
	memset(&action_step2[0], 0, sizeof(Action) * entries);
	for (i = 0, j = 0; i < entries; i++) {
		if (action[i].to_remove)
			continue;	// we are removing this line
		filter_step2[j] = filter[i];
		action_step2[j] = action[i];
		j++;
	}

	// step 3: add the new ret KILL, and recalculate entries
	// (j <= entries - 1 here because at least one line was removed)
	filter_step2[j].code = BPF_RET + BPF_K;
	filter_step2[j].jt = 0;
	filter_step2[j].jf = 0;
	filter_step2[j].k = SECCOMP_RET_KILL;
	int new_entries = j + 1;

	// step 4: recalculate jumps; a jump of N lands on instruction i + 1 + N,
	// and the shared ret KILL sits at new_entries - 1
	for (i = 0; i < new_entries; i++) {
		if (!action_step2[i].to_fix_jumps)
			continue;

		int offset = new_entries - i - 2;
		if (offset > MAX_JUMP) {
			// the jump cannot be encoded; keep the original program
			free(filter_step2);
			return entries;
		}
		filter_step2[i].jt = offset;
		filter_step2[i].jf = 0;
	}

	// update
	memcpy(filter, filter_step2, new_entries * sizeof(struct sock_filter));
	free(filter_step2);
	return new_entries;
}

// Entry point of the optimizer.
//
// Counts the blacklist pairs in filter and, when there are more than
// LIMIT_BLACKLISTS of them, hands the program to optimize_blacklists().
// Returns the number of instructions in filter after the call; this is the
// unchanged 'entries' when the blacklist pass is not run.
int optimize(struct sock_filter *filter, int entries) {
	assert(filter);
	assert(entries);

	// too few "ret KILL" pairs: leave the program alone
	if (count_blacklists(filter, entries) <= LIMIT_BLACKLISTS)
		return entries;

	return optimize_blacklists(filter, entries);
}

// Return a malloc'ed copy of the first 'entries' instructions of filter.
// On allocation failure errExit("malloc") is called, followed by exit(1).
// The returned buffer is released with free() by the caller.
struct sock_filter *duplicate(struct sock_filter *filter, int entries) {
	int nbytes = sizeof(struct sock_filter) * entries;
	struct sock_filter *copy = malloc(nbytes);

	if (copy) {
		memcpy(copy, filter, nbytes);
		return copy;
	}

	errExit("malloc");
	exit(1);
}