# Python CGDAL class -- raster calculation / linear enhancement / filter enhancement for raster data
# -*- coding: UTF-8 -*-
'''
python version: 2.7.11
numpy ver=1.11.1
gdal ver=2.0.3
Author: Liuph
Date: 2016/9/9
Description: This is a GDAL class adapted from the Python GDAL/OGR Cookbook documentation
(http://pcjericks.github.io/py-gdalogr-cookbook/raster_layers.html). The CGDAL class can be used to
load an image and obtain descriptive information about the image file. Basic image processing
is included in the class, so a linear enhancement or a spatial filtering operation can be performed
with the CGDAL class. Moreover, the module offers functions to generate a raster image from a numpy array.
However, its exception handling is not perfect: for instance, an image whose nodata value is "None"
might cause problems, or it might not; this still needs to be confirmed.
In addition, as is well known, code written in Python is very concise and clear. However, by the
author's estimate Python code runs about thirty times slower than equivalent C++ code, which is a
serious and long-standing shortcoming.
May Python be better!
'''
from osgeo import gdal, gdalnumeric, ogr, osr
from PIL import Image, ImageDraw
import os, sys
from gdalconst import *
import struct
import numpy as np
import re
# Make GDAL report failures by raising RuntimeError instead of returning None;
# CGDAL.loadFrom() and CGDAL.getRasterBand() catch RuntimeError accordingly.
gdal.UseExceptions()
class CGDAL:
    """Wrapper around a GDAL raster dataset with simple processing helpers.

    After a successful loadFrom(), every band is cached in ``mpArray`` with
    shape (bands, rows, cols) while ``mpoDataset`` keeps the read-only GDAL
    handle open for band-level queries.  Processing methods write their
    result to a new GeoTIFF through the module-level array2MultiBandsrasterfn().
    """

    def __init__(self):
        # State lives on the instance: the former class-level defaults (e.g.
        # the mutable ``mpArray`` and geotransform list) were shared between
        # instances, and ``mpGeoTransfor`` was a misspelled dead attribute.
        self.mpoDataset = None        # gdal.Dataset opened with GA_ReadOnly
        self.__mpData = None          # cached pixel data; None until loaded
        self.mpArray = np.array([])   # pixel data, shape (bands, rows, cols)
        self.mgDataType = GDT_Byte    # GDAL data type of band 1
        self.msDataType = "Byte"      # readable name of mgDataType
        self.mnRows = self.mnCols = self.mnBands = -1
        self.mnDatalength = -1
        self.mpGeoTransform = []      # GDAL 6-coefficient geotransform
        self.msProjectionRef = ""     # projection as WKT
        self.msFilename = ""
        self.mdInvalidValue = 0.0     # nodata value of band 1 (None if unset)
        self.mnPerPixSize = 1
        self.srcSR = None             # osr.SpatialReference of the image
        self.latLongSR = None         # geographic CRS cloned from srcSR
        self.poTransform = None       # srcSR -> latLongSR
        self.poTransformT = None      # latLongSR -> srcSR

    def __del__(self):
        # Dropping the last reference closes the GDAL dataset.
        self.mpoDataset = None

    def read(self, band, row, col):
        """Return the cached pixel value at 0-based (band, row, col)."""
        return self.mpArray[band, row, col]

    def printimg(self):
        """Print the cached pixel array."""
        print(self.mpArray)

    def isValid(self):
        """Return True once an image has been opened and its pixels cached."""
        return self.__mpData is not None and self.mpoDataset is not None

    # ---- coordinate conversions ------------------------------------------
    # GDAL geotransform convention:
    #   Xgeo = GT[0] + col * GT[1] + row * GT[2]
    #   Ygeo = GT[3] + col * GT[4] + row * GT[5]
    # The old code passed output variables to gdal.ApplyGeoTransform /
    # gdal.InvGeoTransform C-style, which the Python bindings do not support
    # (they return their results), so every conversion failed or returned
    # 0.0.  The affine maths is done here directly, which also avoids the
    # InvGeoTransform return-value change between GDAL 1.x and 2.x.

    @staticmethod
    def _applyGeoTransform(gt, x, y):
        """Map (x, y) through geotransform ``gt``; mirrors GDALApplyGeoTransform."""
        return (gt[0] + x * gt[1] + y * gt[2],
                gt[3] + x * gt[4] + y * gt[5])

    @staticmethod
    def _invGeoTransform(gt):
        """Return the inverse of geotransform ``gt``; mirrors GDALInvGeoTransform.

        Raises ValueError if the transform is singular.
        """
        det = float(gt[1] * gt[5] - gt[2] * gt[4])
        if det == 0.0:
            raise ValueError("Geotransform %s is not invertible" % (tuple(gt),))
        return ((gt[2] * gt[3] - gt[0] * gt[5]) / det,
                gt[5] / det,
                -gt[2] / det,
                (gt[0] * gt[4] - gt[1] * gt[3]) / det,
                -gt[4] / det,
                gt[1] / det)

    def world2Pixel(self, lat, lon):
        """Convert geographic (lat, lon) to fractional pixel coordinates.

        lat/lon are expressed in ``latLongSR`` (the geographic CRS cloned from
        the image CRS).  When no projection was loaded they are used directly
        as ground coordinates.  Returns {'x': column, 'y': row}.
        """
        gx, gy = lon, lat
        if self.poTransformT is not None:
            gx, gy = self.poTransformT.TransformPoint(lon, lat)[:2]
        return self.ground2Pixel(gx, gy)

    def pixel2World(self, x, y):
        """Convert pixel (x=column, y=row) to geographic coordinates.

        Returns {'lon': ..., 'lat': ...} in ``latLongSR``; without a loaded
        projection the ground coordinates are returned unchanged.
        """
        gx, gy = self._applyGeoTransform(self.mpGeoTransform, x, y)
        if self.poTransform is not None:
            # poTransform goes srcSR -> latLongSR; the old code replaced it
            # with the opposite direction on every call.
            gx, gy = self.poTransform.TransformPoint(gx, gy)[:2]
        return {'lon': gx, 'lat': gy}

    def pixel2Ground(self, x, y):
        """Convert pixel (x=column, y=row) to ground coordinates in the image CRS."""
        pX, pY = self._applyGeoTransform(self.mpGeoTransform, x, y)
        return {'pX': pX, 'pY': pY}

    def ground2Pixel(self, pX, pY):
        """Convert ground coordinates in the image CRS to fractional pixel coordinates."""
        inverse = self._invGeoTransform(self.mpGeoTransform)
        x, y = self._applyGeoTransform(inverse, pX, pY)
        return {'x': x, 'y': y}

    # ---- loading -----------------------------------------------------------

    def _initSpatialRefs(self):
        """Build srcSR/latLongSR and the transforms between them.

        All four stay None when the image carries no projection (the old code
        crashed in that case); world coordinates then equal ground coordinates.
        """
        self.srcSR = self.latLongSR = None
        self.poTransform = self.poTransformT = None
        if not self.msProjectionRef:
            return
        try:
            srcSR = osr.SpatialReference(self.msProjectionRef)
            latLongSR = srcSR.CloneGeogCS()
        except RuntimeError as e:
            print('Unable to parse projection: %s' % e)
            return
        self.srcSR = srcSR
        if latLongSR is None:  # no geographic base CRS available
            return
        # GDAL >= 3 follows authority axis order (lat, lon) unless told
        # otherwise; keep the (lon, lat) order this class was written for.
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            srcSR.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            latLongSR.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        self.latLongSR = latLongSR
        self.poTransform = osr.CoordinateTransformation(srcSR, latLongSR)
        self.poTransformT = osr.CoordinateTransformation(latLongSR, srcSR)

    def loadFrom(self, filename):
        """Open ``filename`` read-only and cache all of its pixels in mpArray.

        Returns True on success.  Returns False (after printing a message) if
        the file cannot be opened, has no band, or its data type is unsupported.
        """
        # close the previously opened image
        self.mpoDataset = None
        self.__mpData = None
        try:
            self.mpoDataset = gdal.Open(filename, GA_ReadOnly)
        except RuntimeError as e:
            print('Unable to open %s' % filename)
            print(e)
            return False
        if self.mpoDataset is None:  # only reachable with GDAL exceptions off
            print('Unable to open %s' % filename)
            return False
        self.msFilename = filename
        # image attributes
        self.mnRows = self.mpoDataset.RasterYSize
        self.mnCols = self.mpoDataset.RasterXSize
        self.mnBands = self.mpoDataset.RasterCount
        if self.mnBands < 1:
            print('No raster band found in %s' % filename)
            return False
        band1 = self.mpoDataset.GetRasterBand(1)
        self.mgDataType = band1.DataType
        self.mdInvalidValue = band1.GetNoDataValue()
        # GeoTransform: [0] top-left x, [1] w-e pixel size, [2] row rotation,
        # [3] top-left y, [4] column rotation, [5] n-s pixel size (negative)
        self.mpGeoTransform = self.mpoDataset.GetGeoTransform()
        self.msProjectionRef = self.mpoDataset.GetProjection()
        self._initSpatialRefs()
        # struct/numpy type code and readable name of each supported type
        typeInfo = {
            GDT_Byte: ("B", "Byte"),
            GDT_UInt16: ("H", "Unsigned Int 16"),
            GDT_Int16: ("h", "Signed Int 16"),
            GDT_UInt32: ("I", "Unsigned Int 32"),
            GDT_Int32: ("i", "Signed Int 32"),
            GDT_Float32: ("f", "Float 32"),
            GDT_Float64: ("d", "Float 64"),
        }
        if self.mgDataType not in typeInfo:
            # the old code fell back to "B" and mis-decoded such data
            print('Unsupported data type %s in %s'
                  % (gdal.GetDataTypeName(self.mgDataType), filename))
            return False
        typeformat, self.msDataType = typeInfo[self.mgDataType]
        # ReadRaster() returns every band one after another in native byte
        # order (the reshape below relies on this layout).  np.frombuffer
        # decodes it without the old N-character struct format string.
        raw = np.frombuffer(self.mpoDataset.ReadRaster(),
                            dtype=np.dtype(typeformat))
        # Widen to int64/float64 like the former struct.unpack + np.array
        # path did, so arithmetic on mpArray does not wrap around.
        widened = raw.astype(np.float64 if typeformat in "fd" else np.int64)
        self.mpArray = widened.reshape(self.mnBands, self.mnRows, self.mnCols)
        self.__mpData = self.mpArray
        return True

    # ---- band access and statistics ------------------------------------------

    def getRasterBand(self, band_num):
        """Return the gdal.Band for 1-based ``band_num``; exit(1) if it does not exist."""
        try:
            srcband = self.mpoDataset.GetRasterBand(band_num)
        except RuntimeError as e:
            print('Band ( %i ) not found' % band_num)
            print(e)
            sys.exit(1)  # was exit(0), which reported success on failure
        if srcband is None:
            print('Band ( %i ) not found' % band_num)
            sys.exit(1)
        return srcband

    def getRasterBand2Array(self, band_num):
        """Read 1-based ``band_num`` from disk as a 2-D numpy array."""
        return self.mpoDataset.GetRasterBand(band_num).ReadAsArray()

    def getRasterBandStas(self, band_num):
        """Print and return GDAL statistics [min, max, mean, stddev] of a band.

        Approximate statistics are allowed and computed if missing.  Exits with
        status 1 if the band or its statistics are unavailable.
        """
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)
        stats = srcband.GetStatistics(True, True)
        if stats is None:
            print("Statistics of Band %i is NULL" % band_num)
            sys.exit(1)
        print("[ STATS ] = Minimum=%.3f, Maximum=%.3f, Mean=%.3f, StdDev=%.3f" % (
            stats[0], stats[1], stats[2], stats[3]))
        return stats

    def getRasterBandInfo(self, band_num):
        """Print descriptive metadata of a band, including its color table if any."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)
        print("[ NO DATA VALUE ] = %s" % (srcband.GetNoDataValue(),))
        print("[ MIN ] = %s" % (srcband.GetMinimum(),))
        print("[ MAX ] = %s" % (srcband.GetMaximum(),))
        print("[ SCALE ] = %s" % (srcband.GetScale(),))
        print("[ UNIT TYPE ] = %s" % (srcband.GetUnitType(),))
        ctable = srcband.GetColorTable()
        if ctable is None:
            # A missing color table is normal; the old code exited the program.
            print('No ColorTable found')
            return
        print("[ COLOR TABLE COUNT ] = %s" % (ctable.GetCount(),))
        for i in range(ctable.GetCount()):
            entry = ctable.GetColorEntry(i)
            if not entry:
                continue
            # print the RGB components rather than GetColorEntryAsRGB's status code
            print("[ COLOR ENTRY RGB ] = %s" % (tuple(entry[:3]),))

    def _bandValidValues(self, band_num):
        """Return a float64 copy of band ``band_num`` (1-based) with nodata as NaN.

        Works on a copy: the old statistics methods wrote NaN into the cached
        mpArray in place and raised ValueError for integer rasters.
        """
        values = np.array(self.mpArray[band_num - 1], dtype=np.float64)
        if self.mdInvalidValue is not None:
            values[values == self.mdInvalidValue] = np.nan
        return values

    def getRasterBandMinVal(self, band_num):
        """Minimum of band ``band_num`` (1-based), ignoring mdInvalidValue pixels."""
        return np.nanmin(self._bandValidValues(band_num))

    def getRasterBandMaxVal(self, band_num):
        """Maximum of band ``band_num`` (1-based), ignoring mdInvalidValue pixels."""
        return np.nanmax(self._bandValidValues(band_num))

    def getRasterBandMeanVal(self, band_num):
        """Mean of band ``band_num`` (1-based), ignoring mdInvalidValue pixels."""
        return np.nanmean(self._bandValidValues(band_num))

    def getRasterBandStdVal(self, band_num):
        """Standard deviation of band ``band_num`` (1-based), ignoring nodata."""
        return np.nanstd(self._bandValidValues(band_num))

    def getRasterBandVarVal(self, band_num):
        """Variance of band ``band_num`` (1-based), ignoring nodata."""
        return np.nanvar(self._bandValidValues(band_num))

    # ---- processing ----------------------------------------------------------

    def raster2shp(self, band_num, dst_layername):
        """Polygonize a band into ``<dst_layername>.shp`` (use with care).

        Each polygon stores its pixel value in an integer ``DN`` field and the
        layer carries the image's spatial reference (None if unknown).
        """
        srcband = self.mpoDataset.GetRasterBand(band_num)
        drv = ogr.GetDriverByName("ESRI Shapefile")
        dst_ds = drv.CreateDataSource(dst_layername + ".shp")
        dst_layer = dst_ds.CreateLayer(dst_layername, srs=self.srcSR)
        dst_layer.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger))
        # field index 0 is "DN"; the old -1 discarded the pixel values
        gdal.Polygonize(srcband, None, dst_layer, 0, [], callback=None)
        # release layer and data source so the shapefile is flushed to disk
        dst_layer = None
        dst_ds = None

    def replaceNoData2New(self, ds_fn, new_NoData):
        """Write a copy of the image with each band's nodata pixels set to ``new_NoData``.

        The band's declared nodata value is used (-9999 if none is declared,
        matching the previous hard-coded value).  The output declares
        ``new_NoData`` as its nodata value; the source dataset is not modified.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            org_Nodata = self.mpoDataset.GetRasterBand(band_num).GetNoDataValue()
            if org_Nodata is None:
                org_Nodata = -9999
            rasterArray = self.getRasterBand2Array(band_num).astype(np.float64)
            if np.isnan(org_Nodata):
                mask = np.isnan(rasterArray)   # NaN never compares equal
            else:
                mask = rasterArray == org_Nodata
            rasterArray[mask] = new_NoData
            outArr[band_num - 1] = rasterArray
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, new_NoData)

    def linearEnhance(self, ds_fn, _MinValue, _MaxValue):
        """Linearly stretch every band to [_MinValue, _MaxValue] and save it.

        Each band's valid range comes from getRasterBandMinVal/MaxVal.  Pixels
        outside that range or equal to the band's nodata value are written as
        that nodata value (NaN if none).  A constant band maps to _MinValue.
        The output is written as a Float32 GeoTIFF.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            print("Linear Cal %i/%i" % (band_num, self.mnBands))
            _nodata = self.mpoDataset.GetRasterBand(band_num).GetNoDataValue()
            _array = self.getRasterBand2Array(band_num).astype(np.float64)
            _min = self.getRasterBandMinVal(band_num)
            _max = self.getRasterBandMaxVal(band_num)
            valid = (_array >= _min) & (_array <= _max)
            if _nodata is not None:
                valid &= _array != _nodata
            if _max > _min:
                stretched = (_array - _min) / ((_max - _min) * 1.0) * (
                    _MaxValue - _MinValue) + _MinValue
            else:
                # constant (or empty) band: avoid 0/0
                stretched = np.full(_array.shape, float(_MinValue))
            fill = np.nan if _nodata is None else _nodata
            outArr[band_num - 1] = np.where(valid, stretched, fill)
        print("Writing output data...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    @staticmethod
    def _correlate(band, kernel):
        """Correlate a 2-D band with a square odd kernel; border pixels keep their value."""
        src = np.asarray(band, dtype=float)
        out = src.copy()
        size = kernel.shape[0]
        sub = size // 2
        rows, cols = src.shape
        innerRows, innerCols = rows - 2 * sub, cols - 2 * sub
        if innerRows > 0 and innerCols > 0:
            acc = np.zeros((innerRows, innerCols))
            # same accumulation order as the former per-pixel loop
            for x in range(size):
                for y in range(size):
                    acc += src[x:x + innerRows, y:y + innerCols] * kernel[x, y]
            out[sub:rows - sub, sub:cols - sub] = acc
        return out

    def spatialFiltering(self, ds_fn, sAlgorithm="MeanFiltering", window_size=3):
        """Apply a spatial filter to every band and save the result.

        :param ds_fn: output GeoTIFF path.
        :param sAlgorithm: MeanFiltering, LaplaceFiltering, WallisFiltering,
            SobelXFiltering, SobelYFiltering, LogFiltering (5x5),
            RelievoFiltering, HorizonalMaskFiltering (alias
            HorizontalMaskFiltering), VerticalnMaskFiltering (alias
            VerticalMaskFiltering), DiagonalMaskFiltering or QualcommEdgeDec.
        :param window_size: odd window size; used only by MeanFiltering.

        Pixels within half a window of the border keep their original value.
        An even window size or an unknown algorithm name exits with status 1.
        """
        if window_size < 1 or window_size % 2 == 0:
            print("Please input a uneven number for the window size!")
            sys.exit(1)
        fixedKernels = {
            "LaplaceFiltering": [[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]],
            "WallisFiltering": [[0, -0.25, 0], [-0.25, 1, -0.25], [0, -0.25, 0]],
            "SobelXFiltering": [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            "SobelYFiltering": [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
            "LogFiltering": [[-2, -4, -4, -4, -2],
                             [-4, 0, 8, 0, -4],
                             [-4, 8, 24, 8, -4],
                             [-4, 0, 8, 0, -4],
                             [-2, -4, -4, -4, -2]],
            "RelievoFiltering": [[-3, 0, 0], [0, 0, 0], [0, 0, 3]],
            "HorizonalMaskFiltering": [[3, 3, 3], [-6, -6, -6], [3, 3, 3]],
            "VerticalnMaskFiltering": [[3, -6, 3], [3, -6, 3], [3, -6, 3]],
            "DiagonalMaskFiltering": [[3, 3, -6], [3, -6, 3], [-6, 3, 3]],
            # "Qualcomm" renders 高通 (high-pass).  The original bottom-right
            # coefficient was +1, making the kernel asymmetric and
            # non-zero-sum; it now matches the other three corners.
            "QualcommEdgeDec": [[-1, 0, -1], [0, 4, 0], [-1, 0, -1]],
        }
        # correctly spelled aliases of the historical names
        fixedKernels["HorizontalMaskFiltering"] = fixedKernels["HorizonalMaskFiltering"]
        fixedKernels["VerticalMaskFiltering"] = fixedKernels["VerticalnMaskFiltering"]
        if sAlgorithm == "MeanFiltering":
            algori = np.ones((window_size, window_size)) / float(window_size * window_size)
        elif sAlgorithm in fixedKernels:
            algori = np.array(fixedKernels[sAlgorithm], dtype=float)
        else:
            print("There is no such filtering algorithm called %s" % sAlgorithm)
            sys.exit(1)
        print("Filtering Algorithm: \n%s" % (algori,))
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            outArr[band_num - 1] = self._correlate(self.mpArray[band_num - 1], algori)
            print("Filtered %i/%i" % (band_num, self.mnBands))
        print("Writing output file...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    def rasterCalculation(self, ds_fn, expr="(band3-band2)/(band3+band2)"):
        """Evaluate band math on this image and save the single-band result.

        Bands are referenced as band1, band2, ... (1-based); other numbers in
        the expression are left untouched.  Exits with status 1 for a band
        number outside [1, mnBands].
        """
        def _toSlice(match):
            idx = int(match.group(1))
            if idx < 1 or idx > self.mnBands:
                print("Band %i is out of range [1, %i]" % (idx, self.mnBands))
                sys.exit(1)
            return '1.0*self.mpArray[%i,:,:]' % (idx - 1)

        # A single regex pass replaces only "bandN" tokens; the old digit-wise
        # str.replace also rewrote numeric constants and multi-digit bands.
        expr = re.sub(r'\bband(\d+)\b', _toSlice, expr)
        print(expr)
        # SECURITY: eval runs arbitrary Python -- never pass an untrusted expression.
        resultArr = eval(expr)
        array2MultiBandsrasterfn(self.msFilename, ds_fn, resultArr, 1, self.mdInvalidValue)

    def printRasterAttr(self):
        """Print basic information about the loaded image."""
        print("File Name: %s" % self.msFilename)
        print("Rows: %i Cols: %i Bands: %i Pixel Size: %.2f*%.2f" % (
            self.mnRows, self.mnCols, self.mnBands,
            self.mpGeoTransform[1], -self.mpGeoTransform[5]))
        print("Data Type: %s No-Data Value: %s" % (self.msDataType, self.mdInvalidValue))
        # the geotransform was previously mislabelled "SpatialRef"
        print("GeoTransform: %s \nProjection: %s" % (self.mpGeoTransform, self.msProjectionRef))
def array2MultiBandsrasterfn(rasterfn, newRasterfn, array, bandCount, nodata = None, dataType = None):
    """Write ``array`` as a GeoTIFF that reuses the georeferencing of ``rasterfn``.

    :param rasterfn: existing raster whose size, geotransform and projection
        are copied to the output.
    :param newRasterfn: output GeoTIFF path.
    :param array: pixel values with bandCount * rows * cols elements; it is
        viewed as (bandCount, rows, cols) without changing the caller's array.
    :param bandCount: number of output bands.
    :param nodata: nodata value declared on every output band, or None to
        declare none (passing None used to crash in SetNoDataValue).
    :param dataType: GDAL data type of the output; defaults to gdal.GDT_Float32.
    """
    if dataType is None:
        dataType = gdal.GDT_Float32
    raster = gdal.Open(rasterfn)
    cols = raster.RasterXSize
    rows = raster.RasterYSize
    data = np.reshape(array, (bandCount, rows, cols))
    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bandCount, dataType)
    # Copy all six coefficients so rotated/sheared rasters keep their geometry
    # (the old code forced the rotation terms to 0).
    outRaster.SetGeoTransform(raster.GetGeoTransform())
    # Passing the WKT straight through also works when the source has no projection.
    outRaster.SetProjection(raster.GetProjectionRef())
    for band_num in range(1, bandCount + 1):
        outband = outRaster.GetRasterBand(band_num)
        if nodata is not None:
            outband.SetNoDataValue(nodata)
        outband.WriteArray(data[band_num - 1, :, :])
        outband.FlushCache()
    outband = None
    outRaster.FlushCache()
    outRaster = None  # close the output so the GeoTIFF is complete on disk
    raster = None
    print("write output file -- %s success!" % newRasterfn)
def creatraster(newRasterfn, GeoTransform, projection, datatype, imgdata, cols, rows, bands):
    """Create a GeoTIFF from a numpy array with explicit georeferencing.

    :param newRasterfn: output GeoTIFF path.
    :param GeoTransform: GDAL 6-coefficient geotransform.
    :param projection: projection WKT string.
    :param datatype: GDAL data type of the output bands.
    :param imgdata: numpy array of shape (bands, rows, cols); when bands == 1
        any array with rows * cols elements (e.g. 2-D) is accepted.  The
        caller's array is no longer reshaped in place.
    :param cols, rows, bands: output dimensions.
    """
    if bands == 1:
        data = np.reshape(imgdata, (bands, rows, cols))
    else:
        data = imgdata
    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bands, datatype)
    outRaster.SetGeoTransform(GeoTransform)
    outRaster.SetProjection(projection)
    for i in range(bands):
        outband = outRaster.GetRasterBand(i + 1)
        outband.WriteArray(data[i, :, :])
    outband = None
    outRaster.FlushCache()
    outRaster = None  # close the output so the GeoTIFF is complete on disk
    print("write data succeed!")
def rasterCalculations(ds_fn, expr):
    """Evaluate band math over single-band GeoTIFF files named in ``expr``.

    Only .tif files are supported and each file name is written in the
    expression with its suffix, e.g. ``"(b4.tif-b3.tif)/(b4.tif+b3.tif)"``.
    Names start with a letter or underscore, followed by letters, digits or
    underscores (no directory parts).  Each input's own nodata pixels become
    NaN before evaluation; NaN results are written as the nodata value of the
    alphabetically first input, whose georeferencing is reused for ``ds_fn``.
    Exits with status 1 if no file name is found, a file cannot be loaded, or
    an input has more than one band.
    """
    # The old pattern put literal commas inside [...], left '.' unescaped and
    # excluded the digit 0, so names like "b10.tif" were missed or mangled.
    tifPattern = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]*\.tif\b')
    unim = sorted(set(tifPattern.findall(expr)))
    print(unim)
    if not unim:
        print("No .tif file name found in expression: %s" % expr)
        sys.exit(1)
    mArrs = []
    noDataValues = []
    for item in unim:
        Cgdal = CGDAL()
        if not Cgdal.loadFrom(item):
            sys.exit(1)
        if Cgdal.mnBands != 1:
            print("The input raster is not useful. Only 1 band is required, %i is given." % Cgdal.mnBands)
            sys.exit(1)
        # float copy: assigning NaN into an integer array raises ValueError
        arr = Cgdal.mpArray.astype(np.float64)
        if Cgdal.mdInvalidValue is not None:
            arr[arr == Cgdal.mdInvalidValue] = np.nan
        mArrs.append(arr)
        noDataValues.append(Cgdal.mdInvalidValue)
    no_data = noDataValues[0]
    index = dict((name, i) for i, name in enumerate(unim))
    # One regex pass, so a name that is a substring of another is not corrupted.
    expr = tifPattern.sub(lambda m: '1.0*mArrs[%i]' % index[m.group(0)], expr)
    print(expr)
    # SECURITY: eval runs arbitrary Python -- never pass an untrusted expression.
    resultArr = eval(expr)
    if no_data is not None:
        resultArr = np.where(np.isnan(resultArr), no_data, resultArr)
    array2MultiBandsrasterfn(unim[0], ds_fn, resultArr, 1, nodata=no_data)
# Reposted from (转载自): https://blog.csdn.net/liuph_/article/details/52491123