/*
 * MySQL-backed storage for the routing table (routing_list), the ARP table
 * (ip_mac) and the IP blacklist (ip_fw).
 *
 * The connection handle `conn_db` and the prototypes of the callbacks used
 * before their definition are expected to come from db.h.
 */
#include "db.h"

#include <limits.h>
#include <stdio.h>
#include <string.h>

/* Results of the most recent next-hop, default-gateway and MAC lookups. */
char next_hop_ip[16] = "";
char gateway_ip[16] = "";
char mac_address[18] = "";

enum {
    COLUMN_NAME_SIZE = 30,            /* width of one column-name slot passed to callbacks */
    VALUE_MAX = 100,                  /* longest value accepted for interpolation into SQL */
    ESCAPED_SIZE = 2 * VALUE_MAX + 1, /* worst case for mysql_real_escape_string */
    SQL_SIZE = 1024                   /* room for the longest statement with three escaped values */
};

/* Return non-zero when an snprintf result means the output was complete. */
static int sql_fits(int written, size_t capacity)
{
    return written >= 0 && (size_t)written < capacity;
}

/*
 * Escape `value` for use inside a single-quoted SQL literal.
 * `escaped` must provide ESCAPED_SIZE bytes.
 * Returns 0 on success, -1 when there is no connection, the value is NULL
 * or too long, or the client library reports an error.
 */
static int escape_value(const char *value, char *escaped)
{
    if (conn_db == NULL || value == NULL) {
        return -1;
    }
    size_t len = strlen(value);
    if (len > VALUE_MAX) {
        return -1;
    }
    if (mysql_real_escape_string(conn_db, escaped, value, (unsigned long)len) ==
        (unsigned long)-1) {
        return -1;
    }
    return 0;
}

/* Run a complete SQL statement; returns 0 on success, -1 on failure. */
static int run_sql(const char *sql)
{
    if (conn_db == NULL) {
        return -1;
    }
    return mysql_real_query(conn_db, sql, (unsigned long)strlen(sql)) == 0 ? 0 : -1;
}

/* Narrow a row count to int, saturating instead of wrapping. */
static int clamp_rows(my_ulonglong rows)
{
    return rows > (my_ulonglong)INT_MAX ? INT_MAX : (int)rows;
}

/*
 * Prepare, bind and execute one statement; shared by insert() and delete().
 * Returns the number of affected rows, or -1 on any failure.
 */
static int execute_prepared(const char *sql, MYSQL_BIND *params)
{
    int rc = -1;
    my_ulonglong affected;
    MYSQL_STMT *stmt;

    if (conn_db == NULL || sql == NULL) {
        return -1;
    }
    stmt = mysql_stmt_init(conn_db);
    if (stmt == NULL) {
        return -1;
    }
    if (mysql_stmt_prepare(stmt, sql, (unsigned long)strlen(sql)) != 0) {
        goto out;
    }
    if (params != NULL && mysql_stmt_bind_param(stmt, params) != 0) {
        goto out;
    }
    if (mysql_stmt_execute(stmt) != 0) {
        goto out;
    }
    /* (my_ulonglong)-1 is the library's error marker for the affected-row count. */
    affected = mysql_stmt_affected_rows(stmt);
    if (affected != (my_ulonglong)-1) {
        rc = clamp_rows(affected);
    }
out:
    mysql_stmt_close(stmt);
    return rc;
}

/*
 * Copy the value of column `wanted` from `row` into `dst` (bounded, always
 * NUL-terminated). A missing column or an SQL NULL leaves `dst` untouched.
 */
static void copy_column(MYSQL_ROW row, char (*columns)[30], int cols,
                        const char *wanted, char *dst, size_t dst_size)
{
    if (row == NULL || columns == NULL) {
        return;
    }
    for (int i = 0; i < cols; i++) {
        if (strcmp(columns[i], wanted) == 0) {
            if (row[i] != NULL) {
                snprintf(dst, dst_size, "%s", row[i]);
            }
            break;
        }
    }
}

/*
 * Clear `cache`, run `sql` (skipped when NULL) with `callback` filling
 * `cache`, then mirror the result into `out` when the caller supplied a
 * different buffer. `out` must be at least as large as `cache`.
 */
static void lookup_cached(const char *sql,
                          void (*callback)(MYSQL_ROW row, char (*columns)[30], int cols),
                          char *cache, char *out)
{
    cache[0] = '\0';
    if (sql != NULL) {
        query(sql, callback);
    }
    if (out != NULL && out != cache) {
        strcpy(out, cache);
    }
}

/*
 * The public get_* functions have parameters that shadow the file-scope
 * caches, so these wrappers are what give them access to the globals.
 */
static void lookup_next_hop(const char *sql, char *out)
{
    lookup_cached(sql, fetch_next_hop_ip, next_hop_ip, out);
}

static void lookup_gateway(const char *sql, char *out)
{
    lookup_cached(sql, fetch_gateway_ip, gateway_ip, out);
}

static void lookup_mac(const char *sql, char *out)
{
    lookup_cached(sql, fetch_mac_address, mac_address, out);
}

/**
 * Open the shared connection `conn_db` if it is not already open.
 *
 * @param host     server host name or address
 * @param port     server TCP port
 * @param user     account name
 * @param pwd      account password
 * @param db_name  default database
 * @return 0 when a connection is available (new or pre-existing), -1 on failure.
 *
 * On failure `conn_db` stays NULL so that a later call retries instead of
 * reporting a half-initialised handle as connected.
 */
int connect_mysql(const char *host, int port, const char *user, const char *pwd, const char *db_name)
{
    MYSQL *handle;

    if (conn_db != NULL) {
        return 0;
    }
    /* Let the library allocate the handle so that mysql_close() also frees it. */
    handle = mysql_init(NULL);
    if (handle == NULL) {
        return -1;
    }
    /* The character set has to be requested before connecting; best effort. */
    (void)mysql_options(handle, MYSQL_SET_CHARSET_NAME, "utf8");
    if (mysql_real_connect(handle, host, user, pwd, db_name, (unsigned int)port, NULL, 0) == NULL) {
        mysql_close(handle);
        return -1;
    }
    conn_db = handle;
    return 0;
}

/**
 * Run `sql` and invoke `callback` once per result row.
 *
 * @param sql       complete SQL statement
 * @param callback  receives the row, the column names (each truncated to
 *                  29 characters) and the column count; may be NULL
 * @return 0 on success (including statements that yield no result set),
 *         -1 when not connected, `sql` is NULL, the statement fails or the
 *         result set cannot be retrieved.
 */
int query(const char *sql, void (*callback)(MYSQL_ROW row, char (*columns)[30], int cols))
{
    if (conn_db == NULL || sql == NULL) {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, (unsigned long)strlen(sql)) != 0) {
        return -1;
    }

    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL) {
        /* No result set is normal for non-SELECT statements; with a
         * non-zero field count it means retrieval failed. */
        return mysql_field_count(conn_db) == 0 ? 0 : -1;
    }

    unsigned int cols = mysql_num_fields(res);
    if (cols == 0) {
        mysql_free_result(res);
        return 0;
    }

    char column_names[cols][COLUMN_NAME_SIZE];
    unsigned int filled = 0;
    MYSQL_FIELD *field = NULL;
    while (filled < cols && (field = mysql_fetch_field(res)) != NULL) {
        snprintf(column_names[filled], sizeof column_names[filled], "%s", field->name);
        filled++;
    }
    /* Never hand uninitialised slots to the callback. */
    for (; filled < cols; filled++) {
        column_names[filled][0] = '\0';
    }

    MYSQL_ROW row;
    while ((row = mysql_fetch_row(res)) != NULL) {
        if (callback != NULL) {
            callback(row, column_names, (int)cols);
        }
    }
    mysql_free_result(res);
    return 0;
}

/**
 * Execute a prepared INSERT statement.
 *
 * @param sql     statement text, may contain `?` placeholders
 * @param params  bindings for the placeholders, or NULL when there are none
 * @return number of affected rows, or -1 on failure.
 */
int insert(const char *sql, MYSQL_BIND *params)
{
    return execute_prepared(sql, params);
}

/**
 * Execute a prepared DELETE statement.
 *
 * @param sql     statement text, may contain `?` placeholders
 * @param params  bindings for the placeholders, or NULL when there are none
 * @return number of affected rows, or -1 on failure.
 */
int delete (const char *sql, MYSQL_BIND *params)
{
    return execute_prepared(sql, params);
}

/**
 * Close the shared connection.
 *
 * @return 0 when a connection was closed, -1 when none was open.
 */
int close_mysql(void)
{
    if (conn_db == NULL) {
        return -1;
    }
    mysql_close(conn_db);
    conn_db = NULL; /* the handle is invalid from here on */
    return 0;
}

/**
 * Run `sql` and report how many rows it produced.
 *
 * @return the row count (0 when the statement yields no result set),
 *         or -1 when not connected, `sql` is NULL or the statement fails.
 */
int result_rows(const char *sql)
{
    if (conn_db == NULL || sql == NULL) {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, (unsigned long)strlen(sql)) != 0) {
        return -1;
    }
    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL) {
        return 0;
    }
    my_ulonglong rows = mysql_num_rows(res);
    mysql_free_result(res);
    return clamp_rows(rows);
}

/** Add a route; prints a failure message to stdout when it cannot be stored. */
void insert_routing_list(const char *ip, const char *mask, const char *nexthop)
{
    char ip_esc[ESCAPED_SIZE];
    char mask_esc[ESCAPED_SIZE];
    char hop_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                escape_value(mask, mask_esc) == 0 &&
                escape_value(nexthop, hop_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "insert into routing_list(ip, mask, nexthop) values('%s', '%s', '%s')",
                                  ip_esc, mask_esc, hop_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("添加路由表失败\n");
    }
}

/** Remove the route for `ip`; prints a failure message to stdout on error. */
void delete_routing_list(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "delete from routing_list where ip = '%s'", ip_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("删除路由表失败\n");
    }
}

/** Print every row of routing_list to stdout. */
void print_routing_list(void)
{
    query("select * from routing_list", printResult);
}

/** Add an IP/MAC pair to the ARP table; prints a failure message on error. */
void insert_arp_list(const char *ip, const char *mac)
{
    char ip_esc[ESCAPED_SIZE];
    char mac_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                escape_value(mac, mac_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "insert into ip_mac(ip, mac) values('%s', '%s')",
                                  ip_esc, mac_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("添加ARP表失败\n");
    }
}

/** Replace the MAC stored for `ip`; prints a failure message on error. */
void update_arp_list_by_ip(const char *ip, const char *mac)
{
    char ip_esc[ESCAPED_SIZE];
    char mac_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                escape_value(mac, mac_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "update ip_mac set mac = '%s' where ip = '%s'",
                                  mac_esc, ip_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("更新ARP表失败\n");
    }
}

/**
 * Count the ARP entries stored for `ip`.
 *
 * @return number of matching rows, or -1 on error.
 */
int search_arp_list_if_ip_have(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    if (escape_value(ip, ip_esc) != 0 ||
        !sql_fits(snprintf(sql, sizeof sql,
                           "select * from ip_mac where ip = '%s'", ip_esc),
                  sizeof sql)) {
        return -1;
    }
    return result_rows(sql);
}

/** Print every row of ip_mac to stdout. */
void print_arp_list(void)
{
    query("select * from ip_mac", printResult);
}

/**
 * query() callback that prints one row as "column: value" pairs.
 * SQL NULL values are printed as "(null)".
 */
void printResult(MYSQL_ROW row, char (*columns)[30], int cols)
{
    for (int i = 0; i < cols; i++) {
        printf("%s: %s \t", columns[i], row[i] != NULL ? row[i] : "(null)");
    }
    printf("\n");
}

/** Add `ip` to the blacklist; prints a failure message on error. */
void insert_ip_fw(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "insert into ip_fw(ip) values('%s')", ip_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("添加黑名单失败\n");
    }
}

/**
 * Count the blacklist entries matching `ip`.
 *
 * @return number of matching rows, or -1 on error.
 */
int search_ip_fw(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    if (escape_value(ip, ip_esc) != 0 ||
        !sql_fits(snprintf(sql, sizeof sql,
                           "select * from ip_fw where ip = '%s'", ip_esc),
                  sizeof sql)) {
        return -1;
    }
    return result_rows(sql);
}

/** Print every row of ip_fw to stdout. */
void print_ip_fw(void)
{
    query("select * from ip_fw", printResult);
}

/** Remove `ip` from the blacklist; prints a failure message on error. */
void delete_ip_fw(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "delete from ip_fw where ip = '%s'", ip_esc),
                         sizeof sql);
    if (!built || run_sql(sql) != 0) {
        printf("删除黑名单失败\n");
    }
}

/**
 * @return 1 when routing_list has a row for `dst_ip`, 0 otherwise
 *         (lookup errors also yield 0).
 */
int is_in_routing_table(const char *dst_ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    if (escape_value(dst_ip, ip_esc) != 0 ||
        !sql_fits(snprintf(sql, sizeof sql,
                           "select * from routing_list where ip = '%s'", ip_esc),
                  sizeof sql)) {
        return 0;
    }
    return result_rows(sql) > 0;
}

/**
 * Look up the next hop for `dst_ip`.
 *
 * The result is stored in the file-scope `next_hop_ip` cache and, when the
 * caller passes a different buffer, copied there as well. That buffer must
 * hold at least 16 bytes. The result is an empty string when no route is
 * found or the lookup fails.
 */
void get_next_hop_ip(const char *dst_ip, char *next_hop_ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(dst_ip, ip_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "select nexthop from routing_list where ip = '%s'", ip_esc),
                         sizeof sql);
    lookup_next_hop(built ? sql : NULL, next_hop_ip);
}

/**
 * Look up the next hop of the default route (ip = '0.0.0.0').
 *
 * The result is stored in the file-scope `gateway_ip` cache and, when the
 * caller passes a different buffer, copied there as well. That buffer must
 * hold at least 16 bytes. The result is an empty string when no default
 * route exists or the lookup fails.
 */
void get_default_gateway_ip(char *gateway_ip)
{
    lookup_gateway("select nexthop from routing_list where ip = '0.0.0.0'", gateway_ip);
}

/**
 * @return 1 when ip_mac has a row for `ip`, 0 otherwise
 *         (lookup errors also yield 0).
 */
int is_in_arp_table(const char *ip)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    if (escape_value(ip, ip_esc) != 0 ||
        !sql_fits(snprintf(sql, sizeof sql,
                           "select * from ip_mac where ip = '%s'", ip_esc),
                  sizeof sql)) {
        return 0;
    }
    return result_rows(sql) > 0;
}

/**
 * Look up the MAC address stored for `ip`.
 *
 * The result is stored in the file-scope `mac_address` cache and, when the
 * caller passes a different buffer, copied there as well. That buffer must
 * hold at least 18 bytes. The result is an empty string when no entry is
 * found or the lookup fails.
 */
void get_mac_address(const char *ip, char *mac_address)
{
    char ip_esc[ESCAPED_SIZE];
    char sql[SQL_SIZE];

    int built = escape_value(ip, ip_esc) == 0 &&
                sql_fits(snprintf(sql, sizeof sql,
                                  "select mac from ip_mac where ip = '%s'", ip_esc),
                         sizeof sql);
    lookup_mac(built ? sql : NULL, mac_address);
}

/** query() callback: store the row's "nexthop" column in `next_hop_ip`. */
void fetch_next_hop_ip(MYSQL_ROW row, char (*columns)[30], int cols)
{
    copy_column(row, columns, cols, "nexthop", next_hop_ip, sizeof next_hop_ip);
}

/** query() callback: store the row's "nexthop" column in `gateway_ip`. */
void fetch_gateway_ip(MYSQL_ROW row, char (*columns)[30], int cols)
{
    copy_column(row, columns, cols, "nexthop", gateway_ip, sizeof gateway_ip);
}

/** query() callback: store the row's "mac" column in `mac_address`. */
void fetch_mac_address(MYSQL_ROW row, char (*columns)[30], int cols)
{
    copy_column(row, columns, cols, "mac", mac_address, sizeof mac_address);
}