TRAAPassNode.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. import { Color, Vector2, NearestFilter, Matrix4, RendererUtils, PassNode, QuadMesh, NodeMaterial } from 'three/webgpu';
  2. import { add, float, If, Loop, int, Fn, min, max, clamp, nodeObject, texture, uniform, uv, vec2, vec4, luminance } from 'three/tsl';
// Quad mesh shared at module level; `updateBefore()` assigns the resolve material to it and renders it.
const _quadMesh = /*@__PURE__*/ new QuadMesh();
// Scratch vector handed to `renderer.getSize()` so no new Vector2 is created per frame.
const _size = /*@__PURE__*/ new Vector2();
// Holds the value returned by `RendererUtils.resetRendererState()`; it is passed back to
// `RendererUtils.restoreRendererState()` at the end of `updateBefore()`.
let _rendererState;
  6. /**
  7. * A special render pass node that renders the scene with TRAA (Temporal Reprojection Anti-Aliasing).
  8. *
  9. * Note: The current implementation does not yet support MRT setups.
  10. *
  11. * References:
  12. * - {@link https://alextardif.com/TAA.html}
  13. * - {@link https://www.elopezr.com/temporal-aa-and-the-quest-for-the-holy-trail/}
  14. *
  15. * @augments PassNode
  16. * @three_import import { traaPass } from 'three/addons/tsl/display/TRAAPassNode.js';
  17. */
  18. class TRAAPassNode extends PassNode {
  19. static get type() {
  20. return 'TRAAPassNode';
  21. }
  22. /**
  23. * Constructs a new TRAA pass node.
  24. *
  25. * @param {Scene} scene - The scene to render.
  26. * @param {Camera} camera - The camera to render the scene with.
  27. */
  28. constructor( scene, camera ) {
  29. super( PassNode.COLOR, scene, camera );
  30. /**
  31. * This flag can be used for type testing.
  32. *
  33. * @type {boolean}
  34. * @readonly
  35. * @default true
  36. */
  37. this.isTRAAPassNode = true;
  38. /**
  39. * The clear color of the pass.
  40. *
  41. * @type {Color}
  42. * @default 0x000000
  43. */
  44. this.clearColor = new Color( 0x000000 );
  45. /**
  46. * The clear alpha of the pass.
  47. *
  48. * @type {number}
  49. * @default 0
  50. */
  51. this.clearAlpha = 0;
  52. /**
  53. * The jitter index selects the current camera offset value.
  54. *
  55. * @private
  56. * @type {number}
  57. * @default 0
  58. */
  59. this._jitterIndex = 0;
  60. /**
  61. * Used to save the original/unjittered projection matrix.
  62. *
  63. * @private
  64. * @type {Matrix4}
  65. */
  66. this._originalProjectionMatrix = new Matrix4();
  67. /**
  68. * A uniform node holding the inverse resolution value.
  69. *
  70. * @private
  71. * @type {UniformNode<vec2>}
  72. */
  73. this._invSize = uniform( new Vector2() );
  74. /**
  75. * The render target that holds the current sample.
  76. *
  77. * @private
  78. * @type {?RenderTarget}
  79. * @default null
  80. */
  81. this._sampleRenderTarget = null;
  82. /**
  83. * The render target that represents the history of frame data.
  84. *
  85. * @private
  86. * @type {?RenderTarget}
  87. * @default null
  88. */
  89. this._historyRenderTarget = null;
  90. /**
  91. * Material used for the resolve step.
  92. *
  93. * @private
  94. * @type {NodeMaterial}
  95. */
  96. this._resolveMaterial = new NodeMaterial();
  97. this._resolveMaterial.name = 'TRAA.Resolve';
  98. }
  99. /**
  100. * Sets the size of the effect.
  101. *
  102. * @param {number} width - The width of the effect.
  103. * @param {number} height - The height of the effect.
  104. * @return {boolean} Whether the TRAA needs a restart or not. That is required after a resize since buffer data with different sizes can't be resolved.
  105. */
  106. setSize( width, height ) {
  107. super.setSize( width, height );
  108. let needsRestart = false;
  109. if ( this.renderTarget.width !== this._sampleRenderTarget.width || this.renderTarget.height !== this._sampleRenderTarget.height ) {
  110. this._sampleRenderTarget.setSize( this.renderTarget.width, this.renderTarget.height );
  111. this._historyRenderTarget.setSize( this.renderTarget.width, this.renderTarget.height );
  112. this._invSize.value.set( 1 / this.renderTarget.width, 1 / this.renderTarget.height );
  113. needsRestart = true;
  114. }
  115. return needsRestart;
  116. }
  117. /**
  118. * This method is used to render the effect once per frame.
  119. *
  120. * @param {NodeFrame} frame - The current node frame.
  121. */
  122. updateBefore( frame ) {
  123. const { renderer } = frame;
  124. const { scene, camera } = this;
  125. _rendererState = RendererUtils.resetRendererState( renderer, _rendererState );
  126. //
  127. this._pixelRatio = renderer.getPixelRatio();
  128. const size = renderer.getSize( _size );
  129. const needsRestart = this.setSize( size.width, size.height );
  130. // save original/unjittered projection matrix for velocity pass
  131. camera.updateProjectionMatrix();
  132. this._originalProjectionMatrix.copy( camera.projectionMatrix );
  133. // camera configuration
  134. this._cameraNear.value = camera.near;
  135. this._cameraFar.value = camera.far;
  136. // configure jitter as view offset
  137. const viewOffset = {
  138. fullWidth: this.renderTarget.width,
  139. fullHeight: this.renderTarget.height,
  140. offsetX: 0,
  141. offsetY: 0,
  142. width: this.renderTarget.width,
  143. height: this.renderTarget.height
  144. };
  145. const originalViewOffset = Object.assign( {}, camera.view );
  146. if ( originalViewOffset.enabled ) Object.assign( viewOffset, originalViewOffset );
  147. const jitterOffset = _JitterVectors[ this._jitterIndex ];
  148. camera.setViewOffset(
  149. viewOffset.fullWidth, viewOffset.fullHeight,
  150. viewOffset.offsetX + jitterOffset[ 0 ] * 0.0625, viewOffset.offsetY + jitterOffset[ 1 ] * 0.0625, // 0.0625 = 1 / 16
  151. viewOffset.width, viewOffset.height
  152. );
  153. // configure velocity
  154. const mrt = this.getMRT();
  155. const velocityOutput = mrt.get( 'velocity' );
  156. if ( velocityOutput !== undefined ) {
  157. velocityOutput.setProjectionMatrix( this._originalProjectionMatrix );
  158. } else {
  159. throw new Error( 'THREE:TRAAPassNode: Missing velocity output in MRT configuration.' );
  160. }
  161. // render sample
  162. renderer.setMRT( mrt );
  163. renderer.setClearColor( this.clearColor, this.clearAlpha );
  164. renderer.setRenderTarget( this._sampleRenderTarget );
  165. renderer.render( scene, camera );
  166. renderer.setRenderTarget( null );
  167. renderer.setMRT( null );
  168. // every time when the dimensions change we need fresh history data. Copy the sample
  169. // into the history and final render target (no AA happens at that point).
  170. if ( needsRestart === true ) {
  171. // bind and clear render target to make sure they are initialized after the resize which triggers a dispose()
  172. renderer.setRenderTarget( this._historyRenderTarget );
  173. renderer.clear();
  174. renderer.setRenderTarget( this.renderTarget );
  175. renderer.clear();
  176. renderer.setRenderTarget( null );
  177. renderer.copyTextureToTexture( this._sampleRenderTarget.texture, this._historyRenderTarget.texture );
  178. renderer.copyTextureToTexture( this._sampleRenderTarget.texture, this.renderTarget.texture );
  179. } else {
  180. // resolve
  181. renderer.setRenderTarget( this.renderTarget );
  182. _quadMesh.material = this._resolveMaterial;
  183. _quadMesh.render( renderer );
  184. renderer.setRenderTarget( null );
  185. // update history
  186. renderer.copyTextureToTexture( this.renderTarget.texture, this._historyRenderTarget.texture );
  187. }
  188. // copy depth
  189. renderer.copyTextureToTexture( this._sampleRenderTarget.depthTexture, this.renderTarget.depthTexture );
  190. // update jitter index
  191. this._jitterIndex ++;
  192. this._jitterIndex = this._jitterIndex % ( _JitterVectors.length - 1 );
  193. // restore
  194. if ( originalViewOffset.enabled ) {
  195. camera.setViewOffset(
  196. originalViewOffset.fullWidth, originalViewOffset.fullHeight,
  197. originalViewOffset.offsetX, originalViewOffset.offsetY,
  198. originalViewOffset.width, originalViewOffset.height
  199. );
  200. } else {
  201. camera.clearViewOffset();
  202. }
  203. velocityOutput.setProjectionMatrix( null );
  204. RendererUtils.restoreRendererState( renderer, _rendererState );
  205. }
  206. /**
  207. * This method is used to setup the effect's render targets and TSL code.
  208. *
  209. * @param {NodeBuilder} builder - The current node builder.
  210. * @return {PassTextureNode}
  211. */
  212. setup( builder ) {
  213. if ( this._sampleRenderTarget === null ) {
  214. this._sampleRenderTarget = this.renderTarget.clone();
  215. this._historyRenderTarget = this.renderTarget.clone();
  216. this._sampleRenderTarget.texture.minFiler = NearestFilter;
  217. this._sampleRenderTarget.texture.magFilter = NearestFilter;
  218. const velocityTarget = this._sampleRenderTarget.texture.clone();
  219. velocityTarget.isRenderTargetTexture = true;
  220. velocityTarget.name = 'velocity';
  221. this._sampleRenderTarget.textures.push( velocityTarget ); // for MRT
  222. }
  223. // textures
  224. const historyTexture = texture( this._historyRenderTarget.texture );
  225. const sampleTexture = texture( this._sampleRenderTarget.textures[ 0 ] );
  226. const velocityTexture = texture( this._sampleRenderTarget.textures[ 1 ] );
  227. const depthTexture = texture( this._sampleRenderTarget.depthTexture );
  228. const resolve = Fn( () => {
  229. const uvNode = uv();
  230. const minColor = vec4( 10000 ).toVar();
  231. const maxColor = vec4( - 10000 ).toVar();
  232. const closestDepth = float( 1 ).toVar();
  233. const closestDepthPixelPosition = vec2( 0 ).toVar();
  234. // sample a 3x3 neighborhood to create a box in color space
  235. // clamping the history color with the resulting min/max colors mitigates ghosting
  236. Loop( { start: int( - 1 ), end: int( 1 ), type: 'int', condition: '<=', name: 'x' }, ( { x } ) => {
  237. Loop( { start: int( - 1 ), end: int( 1 ), type: 'int', condition: '<=', name: 'y' }, ( { y } ) => {
  238. const uvNeighbor = uvNode.add( vec2( float( x ), float( y ) ).mul( this._invSize ) ).toVar();
  239. const colorNeighbor = max( vec4( 0 ), sampleTexture.sample( uvNeighbor ) ).toVar(); // use max() to avoid propagate garbage values
  240. minColor.assign( min( minColor, colorNeighbor ) );
  241. maxColor.assign( max( maxColor, colorNeighbor ) );
  242. const currentDepth = depthTexture.sample( uvNeighbor ).r.toVar();
  243. // find the sample position of the closest depth in the neighborhood (used for velocity)
  244. If( currentDepth.lessThan( closestDepth ), () => {
  245. closestDepth.assign( currentDepth );
  246. closestDepthPixelPosition.assign( uvNeighbor );
  247. } );
  248. } );
  249. } );
  250. // sampling/reprojection
  251. const offset = velocityTexture.sample( closestDepthPixelPosition ).xy.mul( vec2( 0.5, - 0.5 ) ); // NDC to uv offset
  252. const currentColor = sampleTexture.sample( uvNode );
  253. const historyColor = historyTexture.sample( uvNode.sub( offset ) );
  254. // clamping
  255. const clampedHistoryColor = clamp( historyColor, minColor, maxColor );
  256. // flicker reduction based on luminance weighing
  257. const currentWeight = float( 0.05 ).toVar();
  258. const historyWeight = currentWeight.oneMinus().toVar();
  259. const compressedCurrent = currentColor.mul( float( 1 ).div( ( max( max( currentColor.r, currentColor.g ), currentColor.b ).add( 1.0 ) ) ) );
  260. const compressedHistory = clampedHistoryColor.mul( float( 1 ).div( ( max( max( clampedHistoryColor.r, clampedHistoryColor.g ), clampedHistoryColor.b ).add( 1.0 ) ) ) );
  261. const luminanceCurrent = luminance( compressedCurrent.rgb );
  262. const luminanceHistory = luminance( compressedHistory.rgb );
  263. currentWeight.mulAssign( float( 1.0 ).div( luminanceCurrent.add( 1 ) ) );
  264. historyWeight.mulAssign( float( 1.0 ).div( luminanceHistory.add( 1 ) ) );
  265. return add( currentColor.mul( currentWeight ), clampedHistoryColor.mul( historyWeight ) ).div( max( currentWeight.add( historyWeight ), 0.00001 ) );
  266. } );
  267. // materials
  268. this._resolveMaterial.fragmentNode = resolve();
  269. return super.setup( builder );
  270. }
  271. /**
  272. * Frees internal resources. This method should be called
  273. * when the effect is no longer required.
  274. */
  275. dispose() {
  276. super.dispose();
  277. if ( this._sampleRenderTarget !== null ) {
  278. this._sampleRenderTarget.dispose();
  279. this._historyRenderTarget.dispose();
  280. }
  281. this._resolveMaterial.dispose();
  282. }
  283. }
  284. export default TRAAPassNode;
  285. // These jitter vectors are specified in integers because it is easier.
  286. // I am assuming a [-8,8) integer grid, but it needs to be mapped onto [-0.5,0.5)
  287. // before being used, thus these integers need to be scaled by 1/16.
  288. //
  289. // Sample patterns reference: https://msdn.microsoft.com/en-us/library/windows/desktop/ff476218%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396
  290. const _JitterVectors = [
  291. [ - 4, - 7 ], [ - 7, - 5 ], [ - 3, - 5 ], [ - 5, - 4 ],
  292. [ - 1, - 4 ], [ - 2, - 2 ], [ - 6, - 1 ], [ - 4, 0 ],
  293. [ - 7, 1 ], [ - 1, 2 ], [ - 6, 3 ], [ - 3, 3 ],
  294. [ - 7, 6 ], [ - 3, 6 ], [ - 5, 7 ], [ - 1, 7 ],
  295. [ 5, - 7 ], [ 1, - 6 ], [ 6, - 5 ], [ 4, - 4 ],
  296. [ 2, - 3 ], [ 7, - 2 ], [ 1, - 1 ], [ 4, - 1 ],
  297. [ 2, 1 ], [ 6, 2 ], [ 0, 4 ], [ 4, 4 ],
  298. [ 2, 5 ], [ 7, 5 ], [ 5, 6 ], [ 3, 7 ]
  299. ];
  300. /**
  301. * TSL function for creating a TRAA pass node for Temporal Reprojection Anti-Aliasing.
  302. *
  303. * @tsl
  304. * @function
  305. * @param {Scene} scene - The scene to render.
  306. * @param {Camera} camera - The camera to render the scene with.
  307. * @returns {TRAAPassNode}
  308. */
  309. export const traaPass = ( scene, camera ) => nodeObject( new TRAAPassNode( scene, camera ) );