# ldcast_code/scripts/evaluate_res.py
# Source: Hugging Face repository page (user: weatherforecast1024).
# Commit message: "Upload folder using huggingface_hub" (commit d2f661a, verified).
import os
import numpy as np
def anomaly_correlation_coefficient(ground_truth, predictions):
    """Return the correlation between truth and prediction, averaged over axis 1.

    For every index ``i`` along axis 1, the slices ``ground_truth[:, i]`` and
    ``predictions[:, i]`` are centred on their own means (no external
    climatology is used) and their population covariance is divided by the
    product of their population standard deviations.  The per-index
    coefficients are then averaged.

    Args:
        ground_truth: Array with at least two dimensions holding observations.
        predictions: Array indexed the same way as ``ground_truth``.

    Returns:
        The mean of the per-index correlation coefficients.  A slice whose
        standard deviation is zero makes the division undefined, which
        propagates as NaN/inf into the returned mean.
    """
    per_variable = []
    for idx in range(ground_truth.shape[1]):
        truth = ground_truth[:, idx]
        forecast = predictions[:, idx]

        # Anomalies relative to each slice's own mean.
        truth_anomaly = truth - np.mean(truth)
        forecast_anomaly = forecast - np.mean(forecast)

        covariance = np.mean(truth_anomaly * forecast_anomaly)
        spread = np.std(truth) * np.std(forecast)
        per_variable.append(covariance / spread)

    return np.mean(per_variable)
def root_mean_square_error(ground_truth, predictions):
    """Return the root of the mean squared difference over all elements.

    Args:
        ground_truth: Array of observed values.
        predictions: Array of predicted values, broadcastable against
            ``ground_truth``.

    Returns:
        ``sqrt(mean((ground_truth - predictions) ** 2))`` as a scalar.
    """
    squared_error = np.square(ground_truth - predictions)
    return np.sqrt(np.mean(squared_error))
def weighted_mean_square_error(ground_truth, predictions, weights):
    """Return the mean of the element-wise weighted squared errors.

    The result is the plain arithmetic mean of
    ``weights * (ground_truth - predictions) ** 2``; it is not divided by the
    sum of the weights.

    Args:
        ground_truth: Array of observed values.
        predictions: Array of predicted values.
        weights: Array of weights, broadcastable against the squared error.

    Returns:
        The weighted mean squared error as a scalar.
    """
    squared_error = np.square(ground_truth - predictions)
    return np.mean(weights * squared_error)
def root_weighted_mean_square_error(ground_truth, predictions, weights):
    """Return the square root of ``weighted_mean_square_error``.

    Args:
        ground_truth: Array of observed values.
        predictions: Array of predicted values.
        weights: Array of weights passed through unchanged.

    Returns:
        The root weighted mean squared error as a scalar.
    """
    return np.sqrt(
        weighted_mean_square_error(ground_truth, predictions, weights)
    )
def crps(predicted_probs, observed_value):
    """Return the score this script reports as CRPS.

    The score is the sum of two terms:

    1. the mean squared gap between the sorted predictions and the
       mid-point plotting positions ``(i + 0.5) / n``;
    2. the mean squared difference between the predictions and the
       observation.

    NOTE(review): this does not look like the usual ensemble CRPS
    definition; confirm the intended formula.  For inputs with more than one
    dimension, ``np.sort`` orders values along the last axis while ``n`` and
    the loop run over the first axis, so the result is an array rather than a
    scalar -- confirm that is intended.

    Args:
        predicted_probs: Array of predicted values.
        observed_value: Observed value(s), broadcastable against
            ``predicted_probs``.

    Returns:
        The combined score (scalar for 1-D input).
    """
    # Sort the predictions in ascending order.
    ascending = np.sort(predicted_probs)
    # Number of entries along the first axis.
    n = len(predicted_probs)

    # Squared distance between each sorted entry and its plotting position.
    score = 0
    for rank, value in enumerate(ascending):
        score += (value - (rank + 0.5) / n) ** 2
    score *= 1 / n

    # Add the mean squared error against the observation.
    score += np.mean((predicted_probs - observed_value) ** 2)
    return score
# Folders containing the ground-truth and prediction files.
# Alternative datasets used previously:
# folder_ground_truth = '/data/data_WF/finetune/output/exp_GT'
# folder_predictions = '/data/data_WF/finetune/output/exp_pred'
# folder_ground_truth = '/data/data_WF/ldcast_precipitation/train_ground_truth'
# folder_predictions = '/data/data_WF/ldcast_precipitation/train'
folder_ground_truth = '/data/data_WF/finetune/output/x_GT'
folder_predictions = '/data/data_WF/finetune/output/x_pred'
folder_weights = ''  # unused while the weighted metrics are disabled

# File names in the ground-truth and prediction folders.
# os.listdir() returns entries in arbitrary order, so both listings are
# sorted to make the pairing below deterministic.
# NOTE(review): this assumes matching files sort into the same position in
# both folders -- confirm the naming scheme.
gt_files = sorted(os.listdir(folder_ground_truth))
pred_files = sorted(os.listdir(folder_predictions))

# zip() stops at the shorter listing, so surface a count mismatch instead of
# silently evaluating a (possibly misaligned) subset.
if len(gt_files) != len(pred_files):
    print(
        "Warning: file count mismatch between ground truth and predictions:",
        len(gt_files), "vs", len(pred_files),
        "- only the first", min(len(gt_files), len(pred_files)),
        "sorted pairs are evaluated and they may be misaligned.",
    )

# Per-file metric values.
# NOTE: the weighted metrics (WMSE / RWMSE) are disabled because no weights
# are loaded; to enable them, load weights from folder_weights and call
# weighted_mean_square_error / root_weighted_mean_square_error in the loop.
mse_list = []
acc_list = []
rmse_list = []
crps_list = []

# Loop over the paired files.
for gt_file, pred_file in zip(gt_files, pred_files):
    # Read the ground-truth and prediction arrays.
    ground_truth = np.load(os.path.join(folder_ground_truth, gt_file))
    predictions = np.load(os.path.join(folder_predictions, pred_file))

    # NOTE(review): the factor 1000 is presumably a unit conversion -- confirm.
    ground_truth = ground_truth * 1000
    predictions = predictions * 1000

    mse = np.mean((ground_truth - predictions)**2)
    acc = anomaly_correlation_coefficient(ground_truth, predictions)
    rmse = root_mean_square_error(ground_truth, predictions)
    crps_score = crps(predictions, ground_truth)

    # Record the metrics for this file.
    mse_list.append(mse)
    acc_list.append(acc)
    rmse_list.append(rmse)
    crps_list.append(crps_score)

# Average each metric across all files.
average_mse = np.mean(mse_list)
average_acc = np.mean(acc_list)
average_rmse = np.mean(rmse_list)
average_crps = np.mean(crps_list)

print("Average Mean Square Error (MSE) across all files:", average_mse)
print("Average Anomaly Correlation Coefficient (ACC) across all files:", average_acc)
print("Average Root Mean Square Error (RMSE) across all files:", average_rmse)
print("Average Continuous Ranked Probability Score (CRPS) across all files:", average_crps)