c-router-emulator/router/db.c

327 lines
7.7 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "db.h"
#include <stdio.h>
/*
 * Result buffers written by the fetch_* query callbacks further down.
 * 16 bytes is large enough for a dotted-quad IPv4 string plus its NUL
 * ("255.255.255.255" is 15 chars); 18 bytes fits "aa:bb:cc:dd:ee:ff" plus NUL.
 */
char next_hop_ip[16] = {0}; /* filled by fetch_next_hop_ip */
char gateway_ip[16] = {0};  /* filled by fetch_gateway_ip */
char mac_address[18] = {0}; /* filled by fetch_mac_address */
int connect_mysql(const char *host, int port, const char *user, const char *pwd, const char *db_name)
{
if (conn_db == NULL)
{
conn_db = (MYSQL *)malloc(sizeof(MYSQL));
if (mysql_init(conn_db) != NULL)
{
mysql_set_character_set(conn_db, "utf8");
// if (mysql_real_connect(conn_db, "localhost", "disen", "disen", "db1",
// 3306, NULL, 0) == NULL)
if (mysql_real_connect(conn_db, host, user, pwd, db_name,
port, NULL, 0) == NULL)
{
return -1;
}
}
else
{
return -1;
}
}
return 0;
}
/*
 * Execute a text SQL statement on conn_db and, when it yields a result set,
 * invoke callback once per row.
 *
 * callback receives the row, the column names (each copied into a 30-byte
 * slot, truncated if longer) and the column count. callback may be NULL,
 * in which case the rows are read and discarded.
 *
 * Returns 0 on success (including statements without a result set), or -1
 * if there is no connection, the query fails, or the result set cannot be
 * retrieved.
 */
int query(const char *sql, void (*callback)(MYSQL_ROW row, char (*columns)[30], int cols))
{
    if (conn_db == NULL || sql == NULL)
    {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        return -1;
    }

    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL)
    {
        /* A NULL result with a non-zero field count means a result set was
         * expected but could not be read; otherwise the statement simply
         * produces no rows (INSERT, UPDATE, ...). */
        return mysql_field_count(conn_db) == 0 ? 0 : -1;
    }

    unsigned int cols = mysql_num_fields(res);
    if (cols == 0)
    {
        /* Avoid declaring a zero-length VLA below. */
        mysql_free_result(res);
        return 0;
    }

    char column_names[cols][30];
    unsigned int n = 0;
    MYSQL_FIELD *field = NULL;
    while (n < cols && (field = mysql_fetch_field(res)) != NULL)
    {
        /* Bounded copy: long column names are truncated rather than
         * overflowing the 30-byte slot the callback signature fixes. */
        snprintf(column_names[n], sizeof column_names[n], "%s", field->name);
        n++;
    }
    for (; n < cols; n++)
    {
        column_names[n][0] = '\0';
    }

    MYSQL_ROW row;
    while ((row = mysql_fetch_row(res)) != NULL)
    {
        if (callback != NULL)
        {
            callback(row, column_names, (int)cols);
        }
    }

    mysql_free_result(res);
    return 0;
}
/*
 * Prepare and execute a data-modifying statement (e.g. INSERT) on conn_db.
 *
 * sql may contain '?' placeholders; params, when non-NULL, is bound to them
 * with mysql_stmt_bind_param(). Pass NULL for a statement without
 * placeholders.
 *
 * Returns the number of affected rows, or -1 on any failure (no
 * connection, statement allocation, prepare, bind or execute error).
 */
int insert(const char *sql, MYSQL_BIND *params)
{
    if (conn_db == NULL || sql == NULL)
    {
        return -1;
    }
    MYSQL_STMT *stmt = mysql_stmt_init(conn_db);
    if (stmt == NULL)
    {
        return -1;
    }

    int ret = -1;
    /* prepare/execute return 0 on success; bind_param returns non-zero
     * on error. Each step runs only if the previous one succeeded. */
    if (mysql_stmt_prepare(stmt, sql, strlen(sql)) == 0 &&
        (params == NULL || mysql_stmt_bind_param(stmt, params) == 0) &&
        mysql_stmt_execute(stmt) == 0)
    {
        /* Such statements have no result set to fetch; the affected-row
         * count comes from mysql_stmt_affected_rows(), which reports an
         * error as (my_ulonglong)-1. */
        my_ulonglong affected = mysql_stmt_affected_rows(stmt);
        if (affected != (my_ulonglong)-1)
        {
            ret = (int)affected;
        }
    }

    mysql_stmt_close(stmt);
    return ret;
}
/*
 * Prepare and execute a data-modifying statement (e.g. DELETE) on conn_db.
 *
 * sql may contain '?' placeholders; params, when non-NULL, is bound to them
 * with mysql_stmt_bind_param(). Pass NULL for a statement without
 * placeholders.
 *
 * Returns the number of affected rows, or -1 on any failure (no
 * connection, statement allocation, prepare, bind or execute error).
 */
int delete(const char *sql, MYSQL_BIND *params)
{
    if (conn_db == NULL || sql == NULL)
    {
        return -1;
    }
    MYSQL_STMT *stmt = mysql_stmt_init(conn_db);
    if (stmt == NULL)
    {
        return -1;
    }

    int ret = -1;
    /* prepare/execute return 0 on success; bind_param returns non-zero
     * on error. Each step runs only if the previous one succeeded. */
    if (mysql_stmt_prepare(stmt, sql, strlen(sql)) == 0 &&
        (params == NULL || mysql_stmt_bind_param(stmt, params) == 0) &&
        mysql_stmt_execute(stmt) == 0)
    {
        /* A DELETE has no result set to fetch; the affected-row count comes
         * from mysql_stmt_affected_rows(), which reports an error as
         * (my_ulonglong)-1. */
        my_ulonglong affected = mysql_stmt_affected_rows(stmt);
        if (affected != (my_ulonglong)-1)
        {
            ret = (int)affected;
        }
    }

    mysql_stmt_close(stmt);
    return ret;
}
/*
 * Close the connection held in conn_db and release its handle.
 *
 * Returns 0 if a connection was closed, -1 if conn_db was already NULL.
 * Afterwards conn_db is NULL, so connect_mysql() can reconnect, and
 * query()/result_rows() report failure instead of using a stale handle.
 */
int close_mysql()
{
    if (conn_db == NULL)
    {
        return -1;
    }
    mysql_close(conn_db);
    /* connect_mysql() allocates the handle with malloc(); mysql_close()
     * only frees handles that mysql_init() allocated itself. */
    free(conn_db);
    conn_db = NULL;
    return 0;
}
/*
 * Run a query on conn_db and return how many rows its result set has.
 *
 * Returns the row count, 0 for statements that produce no result set, or
 * -1 if there is no connection, the query fails, or the result set cannot
 * be retrieved.
 */
int result_rows(const char *sql)
{
    if (conn_db == NULL || sql == NULL)
    {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        return -1;
    }

    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL)
    {
        /* Non-zero field count: a result set was expected but not read. */
        return mysql_field_count(conn_db) == 0 ? 0 : -1;
    }

    my_ulonglong rows = mysql_num_rows(res);
    /* Without this every call leaked the stored result set. */
    mysql_free_result(res);
    return (int)rows;
}
/*
 * Add a route (destination ip, mask, next hop) to the routing_list table.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): the arguments are pasted into the SQL text unescaped, so a
 * value containing a quote breaks the statement. Presumably callers pass
 * dotted-quad strings; confirm at the call sites.
 */
void insert_routing_list(const char *ip, const char *mask, const char *nexthop)
{
    /* The old 100-byte buffer overflowed: three 15-char dotted quads plus
     * the statement text need 108 bytes. */
    char sql[256];
    int len = snprintf(sql, sizeof sql,
                       "insert into routing_list(ip, mask, nexthop) values('%s', '%s', '%s')",
                       ip, mask, nexthop);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("添加路由表失败\n");
    }
}
/*
 * Delete the routing_list row(s) whose destination equals ip.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): ip is pasted into the SQL text unescaped; presumably it is
 * a validated dotted-quad string. Confirm at the call sites.
 */
void delete_routing_list(const char *ip)
{
    char sql[256];
    /* snprintf bounds the write; an oversized ip is rejected, not overflowed. */
    int len = snprintf(sql, sizeof sql, "delete from routing_list where ip = '%s'", ip);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("删除路由表失败\n");
    }
}
/* Dump every row of routing_list to stdout via the printResult callback. */
void print_routing_list()
{
    query("select * from routing_list", printResult);
}
/*
 * Add an (ip, mac) pair to the ip_mac table.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): ip and mac are pasted into the SQL text unescaped;
 * presumably they are validated address strings. Confirm at the call sites.
 */
void insert_arp_list(const char *ip, const char *mac)
{
    char sql[256];
    /* snprintf bounds the write; oversized inputs are rejected. */
    int len = snprintf(sql, sizeof sql, "insert into ip_mac(ip, mac) values('%s', '%s')", ip, mac);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("添加ARP表失败\n");
    }
}
/*
 * Set the mac column of the ip_mac row(s) whose ip equals ip.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): ip and mac are pasted into the SQL text unescaped;
 * presumably they are validated address strings. Confirm at the call sites.
 */
extern void update_arp_list_by_ip(const char *ip, const char *mac)
{
    char sql[256];
    /* snprintf bounds the write; oversized inputs are rejected. */
    int len = snprintf(sql, sizeof sql, "update ip_mac set mac = '%s' where ip = '%s'", mac, ip);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("更新ARP表失败\n");
    }
}
/*
 * Count the ip_mac rows whose ip equals ip.
 *
 * Returns the count from result_rows(), or -1 if the statement does not fit
 * the buffer (the same value result_rows() uses for errors).
 */
int search_arp_list_if_ip_have(const char *ip)
{
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select * from ip_mac where ip = '%s'", ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return -1;
    }
    return result_rows(sql);
}
/* Dump every row of the ip_mac (ARP) table to stdout via printResult. */
void print_arp_list()
{
    query("select * from ip_mac", printResult);
}
/*
 * query() callback that prints one row as "name: value" pairs separated by
 * tabs, then a newline. SQL NULL values are printed as "NULL".
 */
void printResult(MYSQL_ROW row, char (*columns)[30], int cols)
{
    for (int i = 0; i < cols; i++)
    {
        /* row[i] is NULL for an SQL NULL; passing NULL to %s is undefined. */
        printf("%s: %s \t", columns[i], row[i] != NULL ? row[i] : "NULL");
    }
    printf("\n");
}
/*
 * Add ip to the ip_fw (blacklist) table.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): ip is pasted into the SQL text unescaped; presumably it is
 * a validated dotted-quad string. Confirm at the call sites.
 */
void insert_ip_fw(const char *ip)
{
    char sql[256];
    /* snprintf bounds the write; an oversized ip is rejected. */
    int len = snprintf(sql, sizeof sql, "insert into ip_fw(ip) values('%s')", ip);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("添加黑名单失败\n");
    }
}
/*
 * Count the ip_fw (blacklist) rows whose ip equals ip.
 *
 * Returns the count from result_rows(), or -1 if the statement does not fit
 * the buffer (the same value result_rows() uses for errors).
 */
int search_ip_fw(const char *ip)
{
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select * from ip_fw where ip = '%s'", ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return -1;
    }
    return result_rows(sql);
}
/* Dump every row of the ip_fw (blacklist) table to stdout via printResult. */
void print_ip_fw()
{
    query("select * from ip_fw", printResult);
}
/*
 * Remove ip from the ip_fw (blacklist) table.
 * Prints an error message if there is no connection, the statement does
 * not fit the buffer, or the query fails.
 *
 * NOTE(review): ip is pasted into the SQL text unescaped; presumably it is
 * a validated dotted-quad string. Confirm at the call sites.
 */
void delete_ip_fw(const char *ip)
{
    char sql[256];
    /* snprintf bounds the write; an oversized ip is rejected. */
    int len = snprintf(sql, sizeof sql, "delete from ip_fw where ip = '%s'", ip);
    if (conn_db == NULL || len < 0 || (size_t)len >= sizeof sql ||
        mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        printf("删除黑名单失败\n");
    }
}
/*
 * Return 1 if routing_list has at least one row whose ip equals dst_ip,
 * otherwise 0. Query errors and oversized input are both reported as 0.
 */
int is_in_routing_table(const char *dst_ip)
{
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select * from routing_list where ip = '%s'", dst_ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return 0;
    }
    return result_rows(sql) > 0;
}
/*
 * Look up the next hop for dst_ip in routing_list.
 *
 * The value is stored in the file-scope next_hop_ip buffer by the
 * fetch_next_hop_ip callback. If no row matches, or dst_ip is too long for
 * the statement buffer, that buffer keeps its previous contents.
 *
 * NOTE(review): the next_hop_ip parameter shadows the global of the same
 * name and is never written; callers must read the global. It is kept only
 * for interface compatibility.
 */
void get_next_hop_ip(const char *dst_ip, char *next_hop_ip)
{
    (void)next_hop_ip; /* unused, see note above */
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select nexthop from routing_list where ip = '%s'", dst_ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return;
    }
    query(sql, fetch_next_hop_ip);
}
/*
 * Look up the default route (the routing_list row with ip '0.0.0.0').
 *
 * Its next hop is written into the file-scope gateway_ip buffer by the
 * fetch_gateway_ip callback.
 *
 * NOTE(review): the gateway_ip parameter shadows that global and is never
 * written; callers must read the global.
 */
void get_default_gateway_ip(char *gateway_ip)
{
    (void)gateway_ip; /* unused, see note above */
    query("select nexthop from routing_list where ip = '0.0.0.0'", fetch_gateway_ip);
}
/*
 * Return 1 if ip_mac has at least one row whose ip equals ip, otherwise 0.
 * Query errors and oversized input are both reported as 0.
 */
int is_in_arp_table(const char *ip)
{
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select * from ip_mac where ip = '%s'", ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return 0;
    }
    return result_rows(sql) > 0;
}
/*
 * Look up the MAC address recorded for ip in ip_mac.
 *
 * The value is stored in the file-scope mac_address buffer by the
 * fetch_mac_address callback. If no row matches, or ip is too long for the
 * statement buffer, that buffer keeps its previous contents.
 *
 * NOTE(review): the mac_address parameter shadows the global of the same
 * name and is never written; callers must read the global. It is kept only
 * for interface compatibility.
 */
void get_mac_address(const char *ip, char *mac_address)
{
    (void)mac_address; /* unused, see note above */
    char sql[256];
    int len = snprintf(sql, sizeof sql, "select mac from ip_mac where ip = '%s'", ip);
    if (len < 0 || (size_t)len >= sizeof sql)
    {
        return;
    }
    query(sql, fetch_mac_address);
}
/*
 * query() callback: copy the "nexthop" column of row into the file-scope
 * next_hop_ip buffer. Values longer than the buffer are truncated, and an
 * SQL NULL is stored as an empty string.
 */
void fetch_next_hop_ip(MYSQL_ROW row, char (*columns)[30], int cols)
{
    for (int i = 0; i < cols; i++)
    {
        if (strcmp(columns[i], "nexthop") == 0)
        {
            /* Bounded copy: strcpy overflowed the 16-byte buffer on longer
             * values and dereferenced NULL for SQL NULL. */
            snprintf(next_hop_ip, sizeof next_hop_ip, "%s", row[i] != NULL ? row[i] : "");
            break;
        }
    }
}
/*
 * query() callback: copy the "nexthop" column of row into the file-scope
 * gateway_ip buffer. Values longer than the buffer are truncated, and an
 * SQL NULL is stored as an empty string.
 */
void fetch_gateway_ip(MYSQL_ROW row, char (*columns)[30], int cols)
{
    for (int i = 0; i < cols; i++)
    {
        if (strcmp(columns[i], "nexthop") == 0)
        {
            /* Bounded copy: strcpy overflowed the 16-byte buffer on longer
             * values and dereferenced NULL for SQL NULL. */
            snprintf(gateway_ip, sizeof gateway_ip, "%s", row[i] != NULL ? row[i] : "");
            break;
        }
    }
}
/*
 * query() callback: copy the "mac" column of row into the file-scope
 * mac_address buffer. Values longer than the buffer are truncated, and an
 * SQL NULL is stored as an empty string.
 */
void fetch_mac_address(MYSQL_ROW row, char (*columns)[30], int cols)
{
    for (int i = 0; i < cols; i++)
    {
        if (strcmp(columns[i], "mac") == 0)
        {
            /* Bounded copy: strcpy overflowed the 18-byte buffer on longer
             * values and dereferenced NULL for SQL NULL. */
            snprintf(mac_address, sizeof mac_address, "%s", row[i] != NULL ? row[i] : "");
            break;
        }
    }
}