HW1主要是使用liner model 进行pm2.5的预测
作业连接:https://ntumlta.github.io/2017fall-ml-hw1/
由于部分ppf被强,所以这里加上一个百度网盘连接:
https://pan.baidu.com/s/1Ff-3zdzqMEi1W2qUf3Agdg 密码 ooqn
内容是这个作业的相关内容
作业要求:
1. 使用前9个小时的数据,预测出第十个小时的PM2.5的值是多少
2.提供2014年的12个月每个月的前20天的24小时数据作为train data
3.每小时有18组数据(so2 甲烷 之类的指标)
下面解析 作业的sample code
import xlrd import numpy as np data = for i in range( 18): data.append([]) #18 组数据 if __name__ == '__main__': iFileDir = "./"; iFileName = iFileDir + "train.xlsx"; print('iFileName = %r'%iFileName) try: wb = xlrd.open_workbook(iFileName) except: print( "file %s is not exist" % (iFileName) ) for sheet_name in wb.sheet_names(): sheet = wb.sheet_by_name(sheet_name) for row_num in range(1, sheet.nrows): for i in range(3, 27): #3-27 是对应的24小时数据 data_tmp = sheet.cell_value(row_num, i) #将数据转换成浮点数 if data_tmp == 'NR': #NR是没有检测到数据 data_tmp = float(0) data[(row_num-1)].append(float(data_tmp)) x = [] y = []上面只是就相应的train数据吃进18个list
month = 12 #10hour as one data data_len = 9 # 连续9个小时的数据作为输入 data_length = 20 * 24 # 一个月24天 每天20个数据 for i in range(12): # 12 个月 for j in range(480 - 9): # x.append([]) # for t in range(18): # 18组数据 for k in range(9): # 连续9个数据作为一个input x[471*i +j].append(data[t][480*i + j+k]) y.append(data[9][480*i+j+9]) # 第9个list存放的是pm2.5的数据 x = np.array(x) #得到x 每一维度是 18 * 9 个数据, 每个月会有471个维度 一共有471*12个维度 y = np.array(y)y = b + w10*x10 + w20*x20 + ...... w90*x90
+ w11 *x11 +...............................+w91*x91
......
+ w117*x117..............................+w917*x917
x = np.concatenate((np.ones((x.shape[0],1)),x),axis=1) #每个维度加一个数据1 作为bias w = np.zeros(len(x[0])) #选定一个起始点,这里做了维度为1 长度为18 *9 +1的 数值为0的矩阵 l_rate = 10 #初始learning rate repeat = 10000补 Loss function 的公式
x_t = x.transpose() s_gra = np.zeros(len(x[0])) for i in range(repeat): #重复计算10000次 hypo = np.dot(x,w) #得到一个一维矩阵 每个值对应一个y` loss = hypo - y cost = np.sum(loss**2) / len(x) cost_a = math.sqrt(cost) gra = np.dot(x_t,loss) #这里不懂,应该是矩阵的导数 s_gra += gra**2 ada = np.sqrt(s_gra) w = w - l_rate * gra/ada #通过adagrad 更新初始点 print ('iteration: %d | Cost: %f ' % ( i,cost_a))