c-router-emulator/router/db.c

231 lines
5.3 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "db.h"
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Open the global MySQL connection `conn_db` if it is not already open.
 *
 * host, port, user, pwd and db_name are passed straight to
 * mysql_real_connect().
 *
 * Returns 0 on success (or when conn_db is already set) and -1 on failure.
 * On any failure conn_db is reset to NULL, so the rest of this file sees the
 * database as unavailable and a later call can retry.
 */
int connect_mysql(const char *host, int port, const char *user, const char *pwd, const char *db_name)
{
    if (conn_db != NULL)
    {
        return 0; /* already connected */
    }

    /* The handle is caller-allocated; close_mysql() is expected to free it
       after mysql_close(), which only frees handles mysql_init() allocated. */
    conn_db = malloc(sizeof *conn_db);
    if (conn_db == NULL)
    {
        return -1;
    }
    if (mysql_init(conn_db) == NULL)
    {
        free(conn_db);
        conn_db = NULL;
        return -1;
    }

    if (mysql_real_connect(conn_db, host, user, pwd, db_name,
                           (unsigned int)port, NULL, 0) == NULL)
    {
        /* Release the half-initialised handle instead of leaving a stale,
           unconnected pointer that would make later calls believe we are up. */
        mysql_close(conn_db);
        free(conn_db);
        conn_db = NULL;
        return -1;
    }

    /* Best effort: switch the established connection to utf8. */
    mysql_set_character_set(conn_db, "utf8");
    return 0;
}
/*
 * Run a text SQL statement on conn_db and, if it yields a result set, invoke
 * `callback` once per row.
 *
 * sql      - NUL-terminated SQL text.
 * callback - may be NULL; receives the row, the column names (each truncated
 *            to 29 characters so it fits the 30-byte slot) and the column count.
 *
 * Returns 0 on success, -1 if there is no connection, the statement fails,
 * or a result set was expected but could not be retrieved.
 */
int query(const char *sql, void (*callback)(MYSQL_ROW row, char (*columns)[30], int cols))
{
    if (NULL == conn_db || NULL == sql)
    {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        return -1;
    }

    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL)
    {
        /* NULL is normal for statements without a result set; it is an
           error only when the statement should have produced columns. */
        return mysql_field_count(conn_db) == 0 ? 0 : -1;
    }

    unsigned int cols = mysql_num_fields(res);
    if (cols > 0)
    {
        char column_names[cols][30];
        MYSQL_FIELD *fields = mysql_fetch_fields(res);
        for (unsigned int c = 0; c < cols; c++)
        {
            /* Bounded copy: field names may be longer than the 30-byte slot. */
            snprintf(column_names[c], sizeof column_names[c], "%s", fields[c].name);
        }

        MYSQL_ROW row;
        while ((row = mysql_fetch_row(res)) != NULL)
        {
            if (callback != NULL)
            {
                callback(row, column_names, (int)cols);
            }
        }
    }

    mysql_free_result(res);
    return 0;
}
/*
 * Shared implementation of insert() and delete(): prepare `sql` on conn_db,
 * bind `params` (may be NULL when the statement has no placeholders),
 * execute it, and report how many rows it changed.
 *
 * Returns the affected-row count (clamped to INT_MAX), or -1 on any failure.
 */
static int db_exec_prepared(const char *sql, MYSQL_BIND *params)
{
    if (NULL == conn_db || NULL == sql)
    {
        return -1;
    }

    MYSQL_STMT *stmt = mysql_stmt_init(conn_db);
    if (stmt == NULL)
    {
        return -1;
    }

    int ret = -1;
    /* mysql_stmt_* return 0 on success, non-zero on error. */
    if (mysql_stmt_prepare(stmt, sql, strlen(sql)) == 0
        && (params == NULL || mysql_stmt_bind_param(stmt, params) == 0)
        && mysql_stmt_execute(stmt) == 0)
    {
        /* INSERT/DELETE produce no result set, so mysql_stmt_fetch() cannot
           report the row count; mysql_stmt_affected_rows() does. */
        unsigned long long affected = mysql_stmt_affected_rows(stmt);
        if (affected != (unsigned long long)-1)
        {
            ret = affected > INT_MAX ? INT_MAX : (int)affected;
        }
    }

    mysql_stmt_close(stmt);
    return ret;
}

/*
 * Execute a prepared INSERT with optional bound parameters.
 * Returns the number of affected rows, or -1 on failure.
 */
int insert(const char *sql, MYSQL_BIND *params)
{
    return db_exec_prepared(sql, params);
}

/*
 * Execute a prepared DELETE with optional bound parameters.
 * Returns the number of affected rows, or -1 on failure.
 */
int delete (const char *sql, MYSQL_BIND *params)
{
    return db_exec_prepared(sql, params);
}
/*
 * Close the global connection opened by connect_mysql().
 *
 * Returns 0 if a connection was closed, -1 if none was open.
 * Afterwards conn_db is NULL, so connect_mysql() can reconnect and the other
 * functions in this file report "no connection" instead of using a stale handle.
 */
int close_mysql()
{
    if (conn_db == NULL)
    {
        return -1;
    }
    mysql_close(conn_db);
    /* connect_mysql() malloc()s the handle and passes it to mysql_init();
       mysql_close() does not free handles it did not allocate itself. */
    free(conn_db);
    conn_db = NULL;
    return 0;
}
/*
 * Run a text SQL query on conn_db and return how many rows it produced.
 *
 * Returns the row count (clamped to INT_MAX), 0 for statements without a
 * result set, or -1 if there is no connection or the query/result fails.
 */
int result_rows(const char *sql)
{
    if (NULL == conn_db || NULL == sql)
    {
        return -1;
    }
    if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
    {
        return -1;
    }

    MYSQL_RES *res = mysql_store_result(conn_db);
    if (res == NULL)
    {
        /* No result set is fine for non-SELECT statements; otherwise it is
           an error retrieving the rows. */
        return mysql_field_count(conn_db) == 0 ? 0 : -1;
    }

    unsigned long long rows = mysql_num_rows(res);
    /* The result must be released or every call leaks it. */
    mysql_free_result(res);
    return rows > INT_MAX ? INT_MAX : (int)rows;
}
void insert_routing_list(const char *ip, const char *mask, const char *nexthop)
{
char sql[100];
sprintf(sql, "insert into routing_list(ip, mask, nexthop) values('%s', '%s', '%s')", ip, mask, nexthop);
if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
{
printf("添加路由表失败\n");
}
}
void delete_routing_list(const char *ip)
{
char sql[100];
sprintf(sql, "delete from routing_list where ip = '%s'", ip);
if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
{
printf("删除路由表失败\n");
}
}
/*
 * Print every row of routing_list via printResult().
 * The SQL is a constant, so it is passed directly instead of being
 * sprintf'd into a fixed-size stack buffer first.
 */
void print_routing_list()
{
    query("select * from routing_list", printResult);
}
void insert_arp_list(const char *ip, const char *mac)
{
char sql[100];
sprintf(sql, "insert into ip_mac(ip, mac) values('%s', '%s')", ip, mac);
if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
{
printf("添加ARP表失败\n");
}
}
/*
 * Print every row of the ip_mac (ARP) table via printResult().
 * The constant SQL is passed directly; no intermediate buffer is needed.
 */
void print_arp_list()
{
    query("select * from ip_mac", printResult);
}
/*
 * query() callback: print one row as "name: value" pairs on a single line.
 *
 * SQL NULL values arrive as NULL pointers in `row`; they are printed as
 * "NULL" because passing NULL to printf's %s is undefined behaviour.
 */
void printResult(MYSQL_ROW row, char (*columns)[30], int cols)
{
    if (row == NULL || columns == NULL)
    {
        return;
    }
    for (int i = 0; i < cols; i++)
    {
        const char *value = row[i] != NULL ? row[i] : "NULL";
        printf("%s: %s \t", columns[i], value);
    }
    printf("\n");
}
void insert_ip_fw(const char *ip)
{
char sql[100];
sprintf(sql, "insert into ip_fw(ip) values('%s')", ip);
if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
{
printf("添加黑名单失败\n");
}
}
int search_ip_fw(const char *ip)
{
char sql[100];
sprintf(sql, "select * from ip_fw where ip = '%s'", ip);
int ret = result_rows(sql);
return ret;
}
/*
 * Print every row of the ip_fw (blacklist) table via printResult().
 * The constant SQL is passed directly; no intermediate buffer is needed.
 */
void print_ip_fw()
{
    query("select * from ip_fw", printResult);
}
void delete_ip_fw(const char *ip)
{
char sql[100];
sprintf(sql, "delete from ip_fw where ip = '%s'", ip);
if (mysql_real_query(conn_db, sql, strlen(sql)) != 0)
{
printf("删除黑名单失败\n");
}
}