这篇教程python实现矩阵的示例代码写得很实用,希望能帮到您。
矩阵使用python构建一个类,模拟矩阵,可以进行各种矩阵的计算,与各种方便的用法
initfrom array import arrayclass Matrix: def __init__(self, matrix: 'a list of one dimension', shape: 'a tuple of shape' = None, dtype: 'data type code' = 'd'): # matrix一个包含所有元素的列表,shape指定形状,默认为列向量,dtype是数据类型 # 使用一个数组模拟矩阵,通过操作这个数组完成矩阵的运算 self.shape = (len(matrix), 1) if shape: self.shape = shape self.array = array(dtype, matrix)
getitem由于矩阵是一个二维数组,应当支持诸如matrix[1, 2],matrix[1:3, 2],matrix[1:3, 2:4]之类的取值 所以我们需要使用slice类的indice方法实现__getitem__,并支持切片 def __getitem__(self, item: 'a index of two dimensions'): # 使用slice类的indices方法,实现二维切片 rows, cols = item # 下面对传入的指针或者切片进行处理,使其可以统一处理 if isinstance(rows, slice): rows = rows.indices(self.shape[0]) else: rows = (rows, rows + 1) if isinstance(cols, slice): cols = cols.indices(self.shape[1]) else: cols = (cols, cols + 1) res = [] shape = (len(range(*rows)), len(range(*cols))) # 新矩阵的形状 # 通过遍历按照顺序将元素加入新的矩阵 for row in range(*rows): for col in range(*cols): index = row * self.shape[1] + col res.append(self.array[index]) if len(res) == 1: # 若是数则返回数 return res[0] # 若是矩阵则返回矩阵 return Matrix(res, shape) 由于要支持切片,所以需要在方法中新建一个矩阵用于返回值
setitem由于没有序列协议的支持,我们需要自己实现__setitem__ def __setitem__(self, key: 'a index or slice of two dimensions', value): # 使用slice类的indices方法,实现二维切片的广播赋值 rows, cols = key if isinstance(rows, slice): rows = rows.indices(self.shape[0]) else: rows = (rows, rows + 1) if isinstance(cols, slice): cols = cols.indices(self.shape[1]) else: cols = (cols, cols + 1) if isinstance(value, Matrix): # 对于传入的值是矩阵,则需要判断形状 if value.shape != (len(range(*rows)), len(range(*cols))): raise ShapeError # 使用x,y指针取出value中的值赋给矩阵 x = -1 for row in range(*rows): x += 1 y = -1 for col in range(*cols): y += 1 index = row * self.shape[1] + col self.array[index] = value[x, y] else: for row in range(*rows): for col in range(*cols): index = row * self.shape[1] + col self.array[index] = value 若传入的value是一个数,这里的逻辑基本与__getitem__相同,实现了广播。 而若传入的是一个矩阵,则需要判断形状,对对应的元素进行赋值,这是为了方便LU分解。
reshapereshape用于改变形状,对于上面的实现方法,只需要改变matrix.shape就可以了 注意改变前后的总元素数应当一致 def reshape(self, shape: 'a tuple of shape'): if self.shape[0] * self.shape[1] != shape[0] * shape[1]: raise ShapeError self.shape = shape
repr实现__repr__方法,较为美观的打印矩阵 def __repr__(self): shape = self.shape _array = self.array return "[" + ",/n".join(str(list(_array[i * shape[1]:(i + 1) * shape[1]])) for i in range(shape[0])) + "]"
add 与 mul对于加法与乘法的支持,这里的乘法是元素的乘法,不是矩阵的乘法 同样的,实现广播 |